Belle II Software development
PyEstimator.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8#include <tracking/trackFindingCDC/mva/PyEstimator.h>
9#include <framework/utilities/FileSystem.h>
10
11#include <boost/python/import.hpp>
12#include <boost/python/extract.hpp>
13#include <boost/python/list.hpp>
14#include <boost/python/tuple.hpp>
15
16#include <framework/logging/Logger.h>
17#include <cstdio>
18
19using namespace Belle2;
20using namespace TrackFindingCDC;
21
22PyEstimator::PyEstimator(const std::string& pickleFileName)
23 : m_pickleFileName(pickleFileName)
24{
25 try {
26 // Construct an array with one entry
27 // Expand it once the number of variables is known.
28 boost::python::object numpy = boost::python::import("numpy");
29 boost::python::list initValues;
30 initValues.append(0.0);
31 m_array = numpy.attr("array")(initValues);
32 m_nCurrent = boost::python::len(m_array);
33 // boost::python::object array = boost::python::import("array");
34 // m_array = array.attr("array")("d");
35 unpickleEstimator(pickleFileName);
36 } catch (const boost::python::error_already_set&) {
37 PyErr_Print();
38 B2ERROR("Could not construct PyEstimator from " << pickleFileName);
39 }
40}
41
42double PyEstimator::predict(const std::vector<double>& inputVariables)
43{
44 size_t nVars = inputVariables.size();
45 expand(nVars);
46
47 for (size_t iVar = 0; iVar < nVars; ++iVar) {
48 m_array[boost::python::make_tuple(0, iVar)] = inputVariables[iVar];
49 }
50 return predict(m_array);
51}
52
53double PyEstimator::predict(const std::vector<NamedFloatTuple*>& floatTuples)
54{
55 size_t nVars = 0;
56 for (NamedFloatTuple* floatTuple : floatTuples) {
57 nVars += floatTuple->size();
58 }
59 expand(nVars);
60 size_t iVar = 0;
61
62 for (NamedFloatTuple* floatTuple : floatTuples) {
63 for (size_t iTuple = 0; iTuple < floatTuple->size(); ++iTuple) {
64 m_array[boost::python::make_tuple(0, iVar)] = floatTuple->get(iTuple);
65 ++iVar;
66 }
67 }
68 return predict(m_array);
69}
70
71double PyEstimator::predict(boost::python::object& array)
72{
73 boost::python::object predictions;
74 try {
75 predictions = m_predict(array);
77 // In case of a binary classification we take the signal probability
78 boost::python::object prediction = predictions[0];
79 return boost::python::extract<double>(prediction[1]);
80 } else {
81 // In case of regression we take the regression value
82 boost::python::object prediction = predictions[0];
83 return boost::python::extract<double>(prediction);
84 }
85 } catch (const boost::python::error_already_set&) {
86 PyErr_Print();
87 B2ERROR("Estimation failed in python object");
88 }
89 return NAN;
90}
91
92void PyEstimator::unpickleEstimator(const std::string& pickleFileName)
93{
94 try {
95 std::string absPickleFilePath = FileSystem::findFile(pickleFileName);
96 boost::python::object io = boost::python::import("io");
97 boost::python::object pickle = boost::python::import("pickle");
98 boost::python::object pickleFile = io.attr("open")(absPickleFilePath, "rb");
99 boost::python::object estimator = pickle.attr("load")(pickleFile);
100 m_estimator = estimator;
101 } catch (const boost::python::error_already_set&) {
102 PyErr_Print();
103 B2ERROR("Could not open pickle file " << pickleFileName);
104 }
105
106 try {
107 m_predict = m_estimator.attr("predict_proba");
109 } catch (const boost::python::error_already_set&) {
110 // AttributeError occurred, but this is allowed to fail
111 // Clear the exception and carry on.
112 PyErr_Clear();
113 B2INFO("Estimator in " << m_pickleFileName << " is not a binary classifier. Trying as regressor");
114 try {
115 m_predict = m_estimator.attr("predict");
117 } catch (const boost::python::error_already_set&) {
118 PyErr_Print();
119 B2ERROR("Could neither find 'predict' not 'predict_proba' in python estimator from file " << pickleFileName);
120 }
121 }
122}
123
124void PyEstimator::expand(size_t nVars)
125{
126 m_nCurrent = boost::python::len(m_array);
127 if (nVars == m_nCurrent) return;
128 try {
129 boost::python::object numpy = boost::python::import("numpy");
130 boost::python::object shape = boost::python::make_tuple(1, nVars); // one sample with nVars columns
131 m_array = numpy.attr("resize")(m_array, shape);
132 } catch (const boost::python::error_already_set&) {
133 PyErr_Print();
134 B2ERROR("Resize failed");
135 }
136}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:151
An abstract tuple of float value where each value has an associated name.
void expand(size_t nVars)
Reserves space for at least n variables in the input array.
Definition: PyEstimator.cc:124
void unpickleEstimator(const std::string &pickleFileName)
Load the estimator object from the pickled file.
Definition: PyEstimator.cc:92
std::string m_pickleFileName
File name of the pickle file that contains the trained estimator.
Definition: PyEstimator.h:48
PyEstimator(const std::string &pickleFileName)
Construct the Estimator.
Definition: PyEstimator.cc:22
double predict(const std::vector< double > &inputVariables)
Call the predict method of the estimator.
Definition: PyEstimator.cc:42
boost::python::object m_estimator
Retrained python estimator object.
Definition: PyEstimator.h:51
boost::python::object m_predict
Python bound prediction method - cached to avoid repeated lookup.
Definition: PyEstimator.h:54
size_t m_nCurrent
Cache for the current length of the input array.
Definition: PyEstimator.h:60
bool m_is_binary_classification
Internal flag to keep track whether a binary classification with predict_proba is evaluated.
Definition: PyEstimator.h:63
boost::python::object m_array
Array to be served to the estimator.
Definition: PyEstimator.h:57
Abstract base class for different kinds of events.