9#include <mva/methods/ONNX.h>
11#include <framework/logging/Logger.h>
15using namespace Belle2::MVA;
16using namespace Belle2::MVA::ONNX;
// Runs inference on the model using named tensor maps.
// Flattens the two name -> tensor maps into parallel name / Ort::Value
// vectors and forwards them to the run overload taking those vectors.
//
// inputMap  - input tensors keyed by the model's input names.
// outputMap - output tensors keyed by the model's output names.
// NOTE(review): output Ort::Values are created from the caller's BaseTensor
// objects before the call, presumably so ONNX Runtime writes the results
// straight into those tensors' buffers. Confirm in BaseTensor::createOrtTensor().
35void Session::run(
const std::map<std::string, std::shared_ptr<BaseTensor>>& inputMap,
36 const std::map<std::string, std::shared_ptr<BaseTensor>>& outputMap)
38 std::vector<Ort::Value> inputs;
39 std::vector<Ort::Value> outputs;
40 std::vector<const char*> inputNames;
41 std::vector<const char*> outputNames;
// Each name and its value are pushed in the same iteration, so
// inputNames[i] always describes inputs[i].
// The name pointers borrow the maps' key strings, which stay alive for the
// whole call because the maps are referenced parameters.
42 for (
auto& x : inputMap) {
43 inputNames.push_back(x.first.c_str());
44 inputs.push_back(x.second->createOrtTensor());
// Same flattening for the outputs.
46 for (
auto& x : outputMap) {
47 outputNames.push_back(x.first.c_str());
48 outputs.push_back(x.second->createOrtTensor());
// Delegate to the overload that takes the parallel vectors.
50 run(inputNames, inputs, outputNames, outputs);
// (Listing fragment: the start of this signature and the head of the body
// are not visible.) These parameters match the vector-based run call made
// by the map-based Session::run. Names and values come as parallel vectors,
// and the output values are taken by non-const reference.
54 std::vector<Ort::Value>& inputs,
55 const std::vector<const char*>& outputNames,
56 std::vector<Ort::Value>& outputs)
// Tail of the inference call: output name array, output value array and
// number of outputs.
// NOTE(review): presumably this is Ort::Session::Run with m_runOptions.
// The head of the call is not visible here.
59 outputNames.data(), outputs.data(), outputs.size());
// (Listing fragment.) Read the ONNX-specific options from the property tree.
// Each key falls back to the given default when it is absent.
// Name of the output tensor used for predictions (default "output").
64 m_outputName = pt.get<std::string>(
"ONNX_outputName",
"output");
// Filename of the ONNX model (default "model.onnx").
65 m_modelFilename = pt.get<std::string>(
"ONNX_modelFilename",
"model.onnx");
// (Listing fragment.) train() performs no actual training.
// It warns that it only stores an existing ONNX model into an MVA weightfile.
76 B2WARNING(
"The ONNX interface does not perform any training - "
77 "the train method just stores an existing ONNX model into an MVA weightfile.");
// NOTE(review): the guarding condition (original line 78) is not visible.
// Presumably this aborts when no model filename is configured. Confirm.
79 B2FATAL(
"You have to provide a path to an ONNX model "
80 "via `m_modelFilename` in the specific options");
// (Listing fragment, presumably ONNXExpert::configureInputOutputNames.)
// Sets up input and output names and performs consistency checks.
// Query the model's declared input and output names from the ONNX session.
91 const auto& inputNames =
m_session->getOrtSession().GetInputNames();
92 const auto& outputNames =
m_session->getOrtSession().GetOutputNames();
// Only single-input models are supported. The error message lists every
// input the model declares.
// NOTE(review): the condition also triggers for a model with zero inputs,
// where the "multiple inputs" wording is misleading. How msg is reported
// (original lines 101-107) is not visible here.
95 if (inputNames.size() != 1) {
96 std::stringstream msg;
97 msg <<
"Model has multiple inputs: ";
98 for (
auto name : inputNames)
99 msg <<
"\"" << name <<
"\" ";
100 msg <<
"- only single-input models are supported.";
// When the model has exactly one output, that output is used even if the
// configured name differs (per the log message). The rest of this branch
// is not visible here.
108 if (outputNames.size() == 1) {
110 B2INFO(
"Output name of the model is "
112 <<
" - will use that despite the configured name being \""
// Search the model's outputs for the configured name. The search value
// (original line 125) is not visible; presumably it is m_outputName.
124 auto outputFound = std::find(outputNames.begin(), outputNames.end(),
// Error message for a configured output name the model does not provide
// (presumably when the search above fails). It lists the available names.
127 std::stringstream msg;
128 msg <<
"No output named \"" <<
m_outputName <<
"\" found. Instead got ";
129 for (
auto name : outputNames)
130 msg <<
"\"" << name <<
"\" ";
131 msg <<
"- either change your model to contain one named \"" <<
m_outputName
132 <<
"\" or set `m_outputName` in the specific options to one of the available names.";
// (Listing fragment, presumably ONNXExpert::configureOutputValueIndex.)
// Configures the index of the value to use from the configured output tensor.
// Inspects the model's outputs. tensorIndex is set in lines not visible here.
// NOTE(review): if GetOutputNames() yields std::string elements, `auto name`
// copies each one; `const auto&` would avoid the copy.
140 for (
auto name :
m_session->getOrtSession().GetOutputNames()) {
145 auto typeInfo =
m_session->getOrtSession().GetOutputTypeInfo(tensorIndex);
146 auto shape = typeInfo.GetTensorTypeAndShapeInfo().GetShape();
// A last dimension of size 2 gets special handling (branch body not visible).
// Presumably this is a two-class output from which one value is picked.
// NOTE(review): shape.back() is undefined for a rank-0 (scalar) output, whose
// shape is empty. Confirm that scalar outputs cannot reach this point.
147 if (shape.back() == 2) {
// (Listing fragment, presumably ONNXExpert::load.)
// Fetch the model stored under the "ONNX_Modelfile" key into a file name
// generated by the weightfile (presumably writing it to disk), then open an
// inference session on that file.
159 std::string onnxModelFileName = weightfile.generateFileName();
160 weightfile.getFile(
"ONNX_Modelfile", onnxModelFileName);
// NOTE(review): Session's constructor takes a std::string, so .c_str() only
// forces a round-trip through a temporary string. Passing onnxModelFileName
// directly would be equivalent.
163 m_session = std::make_unique<Session>(onnxModelFileName.c_str());
// (Listing fragment, presumably ONNXExpert::apply.)
// Capacity for one result value per event is reserved up front.
175 std::vector<float> result;
176 result.reserve(nEvents);
177 for (
unsigned int iEvent = 0; iEvent < nEvents; ++iEvent) {
// Copy the current event's feature values into the input tensor.
// Loading of the event (original line 178) is not visible; presumably it is
// testData.loadEvent(iEvent).
179 input->setValues(testData.
m_input);
// (Listing fragment, presumably ONNXExpert::applyMulticlass.)
// nEvents x nClasses result matrix, allocated up front.
193 std::vector<std::vector<float>> result(nEvents, std::vector<float>(nClasses));
194 for (
unsigned int iEvent = 0; iEvent < nEvents; ++iEvent) {
// Copy the current event's feature values into the input tensor.
196 input->setValues(testData.
m_input);
// Copy each class value from the output tensor into this event's row.
// The inference call itself (original line 197) is not visible here.
198 for (
unsigned int iClass = 0; iClass < nClasses; ++iClass) {
199 result[iEvent][iClass] = output->at(iClass);
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
GeneralOptions m_general_options
General options loaded from the weightfile.
void configureInputOutputNames()
Set up input and output names and perform consistency checks.
ONNXOptions m_specific_options
ONNX specific options loaded from weightfile.
std::unique_ptr< ONNX::Session > m_session
The ONNX inference session wrapper.
std::string m_outputName
Name of the output tensor (either determined automatically or loaded from the specific options).
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
std::string m_inputName
Name of the input tensor (determined automatically).
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
int m_outputValueIndex
Index of the output value to pick in non-multiclass mode.
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
void configureOutputValueIndex()
Configure index of the value to be used for the configured output tensor.
std::string m_modelFilename
Filename of the model.
std::string m_outputName
Name of the output Tensor that is used to make predictions.
virtual void load(const boost::property_tree::ptree &) override
Load mechanism to load Options from an XML tree.
virtual void save(boost::property_tree::ptree &) const override
Save mechanism to store Options in an XML tree.
ONNXOptions m_specific_options
Method specific options.
virtual Weightfile train(Dataset &) const override
Won't do any actual training, but will return a valid MVA Weightfile.
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Ort::Env m_env
Environment object for ONNX session.
Session(const std::string filename)
Constructs a new ONNX Runtime Session using the specified model file.
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
void run(const std::map< std::string, std::shared_ptr< BaseTensor > > &inputMap, const std::map< std::string, std::shared_ptr< BaseTensor > > &outputMap)
Runs inference on the model using named Tensor maps.
static auto make_shared(std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from shape.
GeneralOptions m_general_options
GeneralOptions containing all shared options.
The Weightfile class serializes all information about a training into an XML tree.