9#include <mva/methods/ONNX.h>
11#include <framework/logging/Logger.h>
15using namespace Belle2::MVA;
16using namespace Belle2::MVA::ONNX;
// Session::run (map-based overload): runs inference using named Tensor maps
// (per the doxygen brief for this member). Collects C-string names and
// Ort::Value handles from the input/output maps, then delegates to the
// vector-based run() overload.
// NOTE(review): this is a doxygen source listing — several original lines
// (opening brace, loop-closing braces, orig. lines 37/45/49/51) are elided
// by the extraction, so the braces below do not balance as shown.
35void Session::run(
const std::map<std::string, std::shared_ptr<BaseTensor>>& inputMap,
36 const std::map<std::string, std::shared_ptr<BaseTensor>>& outputMap)
38 std::vector<Ort::Value> inputs;
39 std::vector<Ort::Value> outputs;
40 std::vector<const char*> inputNames;
41 std::vector<const char*> outputNames;
// Gather input tensor names and ORT tensor views. The c_str() pointers
// stay valid because the map keys outlive this call.
42 for (
auto& x : inputMap) {
43 inputNames.push_back(x.first.c_str());
44 inputs.push_back(x.second->createOrtTensor());
// Same for the outputs: ORT writes results directly into these tensors.
46 for (
auto& x : outputMap) {
47 outputNames.push_back(x.first.c_str());
48 outputs.push_back(x.second->createOrtTensor());
// Delegate to the low-level overload (orig. line 53 ff., partially elided).
50 run(inputNames, inputs, outputNames, outputs);
// Tail of the vector-based Session::run overload. Its first signature line
// (orig. line 53, presumably `void Session::run(const std::vector<const
// char*>& inputNames, ...)`) is elided by the extraction — confirm against
// the real source. Orig. line 59 appears to be the argument tail of an
// Ort::Session::Run(...) call (run options / names / values / counts).
54 std::vector<Ort::Value>& inputs,
55 const std::vector<const char*>& outputNames,
56 std::vector<Ort::Value>& outputs)
59 outputNames.data(), outputs.data(), outputs.size());
// Fragment (orig. line 64): reads the configured output-tensor name from the
// property tree, defaulting to "output" when the key is absent. Presumably
// part of ONNXOptions::load(const boost::property_tree::ptree&) — the
// enclosing signature is elided by the extraction; verify in the real source.
64 m_outputName = pt.get<std::string>(
"ONNX_outputName",
"output");
// Fragments of Expert::configureInputOutputNames (doxygen brief: "Set up
// input and output names and perform consistency checks"). Many original
// lines (76-77, 84-90, 92, 94, 96-106, 108-109, and the B2FATAL/B2INFO
// bodies) are elided by the extraction, so control flow below is incomplete.
// Query the model's declared tensor names from the ORT session.
74 const auto& inputNames =
m_session->getOrtSession().GetInputNames();
75 const auto& outputNames =
m_session->getOrtSession().GetOutputNames();
// Consistency check: only single-input models are supported; build a
// descriptive message listing all offending input names (the statement that
// raises/logs it, orig. lines 84-90, is elided here).
78 if (inputNames.size() != 1) {
79 std::stringstream msg;
80 msg <<
"Model has multiple inputs: ";
81 for (
auto name : inputNames)
82 msg <<
"\"" << name <<
"\" ";
83 msg <<
"- only single-input models are supported.";
// If the model has exactly one output, use it regardless of the configured
// name, and inform the user (message body partially elided).
91 if (outputNames.size() == 1) {
93 B2INFO(
"Output name of the model is "
95 <<
" - will use that despite the configured name being \""
// Otherwise, look for the configured m_outputName among the model's outputs;
// if absent, build an error listing the available names (the branch that
// emits it, orig. lines 108-109 and 116+, is elided).
107 auto outputFound = std::find(outputNames.begin(), outputNames.end(),
110 std::stringstream msg;
111 msg <<
"No output named \"" <<
m_outputName <<
"\" found. Instead got ";
112 for (
auto name : outputNames)
113 msg <<
"\"" << name <<
"\" ";
114 msg <<
"- either change your model to contain one named \"" <<
m_outputName
115 <<
"\" or set `m_outputName` in the specific options to one of the available names.";
// Fragments of Expert::load(Weightfile&) (doxygen brief: "Load the expert
// from a Weightfile"). The surrounding lines (options loading, temp filename
// generation via generateFileName, orig. lines before 123) are elided.
// Extract the serialized ONNX model from the weightfile into a local file,
// then construct the ORT inference session wrapper from it.
123 weightfile.
getFile(
"ONNX_Modelfile", onnxModelFileName);
126 m_session = std::make_unique<Session>(onnxModelFileName.c_str());
// Fragments of Expert::apply(Dataset&) (doxygen brief: "Apply this expert
// onto a dataset"). Elided lines include the event-loading call
// (loadEvent(iEvent), presumably orig. line 139) and the session run
// (orig. line 141) — confirm against the real source.
136 std::vector<float> result;
137 result.reserve(nEvents);
138 for (
unsigned int iEvent = 0; iEvent < nEvents; ++iEvent) {
// Copy the current event's feature values into the input tensor, then
// (after the elided run call) read the single scalar prediction.
140 input->setValues(testData.
m_input);
142 result.push_back(output->at(0));
// Fragments of Expert::applyMulticlass(Dataset&) (doxygen brief: "Apply this
// expert onto a dataset and return multiple outputs"). The loadEvent and
// session-run calls between orig. lines 155 and 159 are elided by the
// extraction.
// One row per event, one column per class, value-initialized up front.
154 std::vector<std::vector<float>> result(nEvents, std::vector<float>(nClasses));
155 for (
unsigned int iEvent = 0; iEvent < nEvents; ++iEvent) {
// Feed the current event's features to the model's input tensor.
157 input->setValues(testData.
m_input);
// After the (elided) run, copy each class score out of the output tensor.
159 for (
unsigned int iClass = 0; iClass < nClasses; ++iClass) {
160 result[iEvent][iClass] = output->at(iClass);
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
GeneralOptions m_general_options
General options loaded from the weightfile.
void configureInputOutputNames()
Set up input and output names and perform consistency checks.
ONNXOptions m_specific_options
ONNX specific options loaded from weightfile.
std::unique_ptr< ONNX::Session > m_session
The ONNX inference session wrapper.
std::string m_outputName
Name of the output tensor (will either be determined automatically or loaded from specific options)
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
std::string m_inputName
Name of the input tensor (will be determined automatically)
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
std::string m_outputName
Name of the output Tensor that is used to make predictions.
virtual void load(const boost::property_tree::ptree &) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &) const override
Save mechanism to store Options in a xml tree.
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Ort::Env m_env
Environment object for ONNX session.
Session(const std::string filename)
Constructs a new ONNX Runtime Session using the specified model file.
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
void run(const std::map< std::string, std::shared_ptr< BaseTensor > > &inputMap, const std::map< std::string, std::shared_ptr< BaseTensor > > &outputMap)
Runs inference on the model using named Tensor maps.
static auto make_shared(std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from shape.
The Weightfile class serializes all information about a training into an xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)