Belle II Software development
PyEstimator.h
/**************************************************************************
 * basf2 (Belle II Analysis Software Framework)                           *
 * Author: The Belle II Collaboration                                     *
 *                                                                        *
 * See git log for contributors and copyright holders.                    *
 * This file is licensed under LGPL-3.0, see LICENSE.md.                  *
 **************************************************************************/
#pragma once

#include <tracking/trackFindingCDC/varsets/NamedFloatTuple.h>

#include <boost/python/object.hpp>

#include <cstddef>
#include <string>
#include <vector>

13namespace Belle2 {
18 namespace TrackFindingCDC {
19
20 class NamedFloatTuple;
21
24 public:
28 explicit PyEstimator(const std::string& pickleFileName);
29
31 double predict(const std::vector<double>& inputVariables);
32
34 double predict(const std::vector<NamedFloatTuple*>& floatTuples);
35
37 double predict(boost::python::object& array);
38
39 private:
41 void unpickleEstimator(const std::string& pickleFileName);
42
44 void expand(size_t nVars);
45
46 private:
48 std::string m_pickleFileName;
49
51 boost::python::object m_estimator;
52
54 boost::python::object m_predict;
55
57 boost::python::object m_array;
58
60 size_t m_nCurrent = 0;
61
64
65 };
66
67 }
69}
Class to invoke a pretrained python estimator that follows the sklearn interface.
Definition: PyEstimator.h:23
void expand(size_t nVars)
Reserves space for at least n variables in the input array.
Definition: PyEstimator.cc:124
void unpickleEstimator(const std::string &pickleFileName)
Load the estimator object from the pickled file.
Definition: PyEstimator.cc:92
std::string m_pickleFileName
File name of the pickle file that contains the trained estimator.
Definition: PyEstimator.h:48
double predict(const std::vector< double > &inputVariables)
Call the predict method of the estimator.
Definition: PyEstimator.cc:42
boost::python::object m_estimator
Pretrained python estimator object.
Definition: PyEstimator.h:51
boost::python::object m_predict
Python bound prediction method - cached to avoid repeated lookup.
Definition: PyEstimator.h:54
size_t m_nCurrent
Cache for the current length of the input array.
Definition: PyEstimator.h:60
bool m_is_binary_classification
Internal flag to keep track of whether a binary classification with predict_proba is evaluated.
Definition: PyEstimator.h:63
boost::python::object m_array
Array to be served to the estimator.
Definition: PyEstimator.h:57
Abstract base class for different kinds of events.