10 #include <tracking/trackFindingCDC/mva/PyEstimator.h>
11 #include <framework/utilities/FileSystem.h>
13 #include <boost/python/import.hpp>
14 #include <boost/python/extract.hpp>
15 #include <boost/python/list.hpp>
16 #include <boost/python/tuple.hpp>
18 #include <framework/logging/Logger.h>
22 using namespace TrackFindingCDC;
// NOTE(review): this listing is an extract. Function signatures, `try {` openers,
// closing braces and some statements are not present, so the comments below
// describe only the statements that are actually visible.
//
// Member-initializer list: keeps the given pickle file name in m_pickleFileName.
25 : m_pickleFileName(pickleFileName)
// Import the Python module numpy and set m_array to numpy.array([0.0]),
// i.e. an array built from a one-element Python list.
29 boost::python::object numpy = boost::python::import(
"numpy");
30 boost::python::list initValues;
31 initValues.append(0.0);
32 m_array = numpy.attr(
"array")(initValues);
// Handler for boost::python::error_already_set (a Python exception surfaced by
// Boost.Python); it reports the failure through B2ERROR together with the file name.
// Presumably this closes a try block opened on a line that is not shown.
37 }
catch (
const boost::python::error_already_set&)
40 B2ERROR(
"Could not construct PyEstimator from " << pickleFileName);
// Number of values to hand over, taken from the size of inputVariables.
45 size_t nVars = inputVariables.size();
// Copy every input value into m_array at the 2-D index (0, iVar).
// Presumably m_array was resized to shape (1, nVars) beforehand -- the call
// doing so is not visible here, TODO confirm.
48 for (
size_t iVar = 0; iVar < nVars; ++iVar) {
49 m_array[boost::python::make_tuple(0, iVar)] = inputVariables[iVar];
// Add the size of floatTuple to the running total nVars.
// Presumably executed once per tuple inside a loop that is not shown.
58 nVars += floatTuple->size();
// Write each entry of floatTuple into m_array at the 2-D index (0, iVar).
// NOTE(review): iVar is neither declared nor advanced on the visible lines --
// confirm it is incremented for every entry written, otherwise all entries
// land in the same slot.
64 for (
size_t iTuple = 0; iTuple < floatTuple->size(); ++iTuple) {
65 m_array[boost::python::make_tuple(0, iVar)] = floatTuple->get(iTuple);
// Python object that holds the predictions; the statement filling it is not shown.
74 boost::python::object predictions;
// Variant 1: take the first entry of predictions and return its element [1] as double.
// Presumably the positive-class probability of a predict_proba result -- TODO confirm.
79 boost::python::object prediction = predictions[0];
80 return boost::python::extract<double>(prediction[1]);
// Variant 2: take the first entry of predictions and return it directly as double.
// The condition selecting between the two variants is not visible here.
83 boost::python::object prediction = predictions[0];
84 return boost::python::extract<double>(prediction);
// A Python error raised during the estimation is caught and reported via B2ERROR.
86 }
catch (
const boost::python::error_already_set&) {
88 B2ERROR(
"Estimation failed in python object");
// Import the Python modules io and pickle, open the file at absPickleFilePath in
// binary read mode ("rb") and unpickle the estimator object from it.
// NOTE(review): pickle.load can execute arbitrary code while unpickling -- only
// files from a trusted origin should reach this point.
// NOTE(review): no close() of pickleFile is visible on these lines -- confirm the
// handle is released.
97 boost::python::object io = boost::python::import(
"io");
98 boost::python::object pickle = boost::python::import(
"pickle");
99 boost::python::object pickleFile = io.attr(
"open")(absPickleFilePath,
"rb");
100 boost::python::object estimator = pickle.attr(
"load")(pickleFile);
// A Python error while opening or loading the pickle file is reported via B2ERROR.
102 }
catch (
const boost::python::error_already_set&) {
104 B2ERROR(
"Could not open pickle file " << pickleFileName);
// On a Python error in the preceding (not shown) statements only an info message
// is emitted: the estimator is treated as not being a binary classifier and is
// tried as a regressor instead.
110 }
catch (
const boost::python::error_already_set&) {
114 B2INFO(
"Estimator in " <<
m_pickleFileName <<
" is not a binary classifier. Trying as regressor");
// If the second attempt raises a Python error as well, this is reported as an error.
// NOTE(review): the message reads "not" where "nor" is presumably intended.
118 }
catch (
const boost::python::error_already_set&) {
120 B2ERROR(
"Could neither find 'predict' not 'predict_proba' in python estimator from file " << pickleFileName);
// Import numpy and build the Python tuple (1, nVars) as the target shape.
// The statement that applies this shape is not visible here.
130 boost::python::object numpy = boost::python::import(
"numpy");
131 boost::python::object shape = boost::python::make_tuple(1, nVars);
// A Python error during the resize is caught and reported via B2ERROR.
133 }
catch (
const boost::python::error_already_set&) {
135 B2ERROR(
"Resize failed");