#include <mva/methods/ONNX.h>
#include <framework/logging/Logger.h>
using namespace Belle2::MVA;
weightfile.getFile("ONNX_Modelfile", onnxModelFileName);
std::vector<float> result;
result.push_back(view.outputData()[0]);
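The two lines above come from the single-output apply(): one score per event, read from index 0 of the output buffer. The sketch below reproduces that loop under stated assumptions. ToyDataset and toyRun are hypothetical stand-ins for a Dataset and for running the ONNX session; they are not the Belle II or ONNX Runtime API.

```cpp
#include <cmath>
#include <vector>

// Hypothetical stand-in for a Dataset: one feature vector per event,
// with the currently loaded event copied into m_input.
struct ToyDataset {
  std::vector<std::vector<float>> events;
  std::vector<float> m_input;

  unsigned int getNumberOfEvents() const { return static_cast<unsigned int>(events.size()); }
  void loadEvent(unsigned int iEvent) { m_input = events[iEvent]; }
};

// Hypothetical model standing in for the ONNX session: a logistic
// function of the feature sum, producing a single output value.
std::vector<float> toyRun(const std::vector<float>& input) {
  float sum = 0.0f;
  for (float x : input) sum += x;
  return {1.0f / (1.0f + std::exp(-sum))};
}

// Single-output apply: load each event, run the model, keep output[0].
std::vector<float> applySketch(ToyDataset& testData) {
  std::vector<float> result;
  result.reserve(testData.getNumberOfEvents());
  for (unsigned int iEvent = 0; iEvent < testData.getNumberOfEvents(); ++iEvent) {
    testData.loadEvent(iEvent);
    std::vector<float> outputs = toyRun(testData.m_input);
    result.push_back(outputs[0]);
  }
  return result;
}
```

Reserving the result up front avoids repeated reallocation, since the number of events is known before the loop starts.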
auto outputs = view.outputData();
result[iEvent][iClass] = outputs[iClass];
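applyMulticlass() fills a two-dimensional result, one row per event and one column per class. The following is a minimal sketch of that pattern, assuming a hypothetical ToyDataset and a hypothetical toySoftmax() in place of the ONNX model. The class count passed in is assumed to match the model's output size.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a Dataset: one feature vector per event,
// with the currently loaded event copied into m_input.
struct ToyDataset {
  std::vector<std::vector<float>> events;
  std::vector<float> m_input;

  unsigned int getNumberOfEvents() const { return static_cast<unsigned int>(events.size()); }
  void loadEvent(unsigned int iEvent) { m_input = events[iEvent]; }
};

// Hypothetical multiclass model: a softmax over the input features, so the
// number of classes equals the number of features. Assumes non-empty input.
std::vector<float> toySoftmax(const std::vector<float>& input) {
  float maxVal = input[0];
  for (float x : input) {
    if (x > maxVal) maxVal = x;
  }
  std::vector<float> out(input.size());
  float norm = 0.0f;
  for (std::size_t i = 0; i < input.size(); ++i) {
    out[i] = std::exp(input[i] - maxVal);  // shift by the max for numerical stability
    norm += out[i];
  }
  for (float& v : out) v /= norm;
  return out;
}

// Multiclass apply: result[iEvent][iClass] holds the model output for that class.
// nClasses must equal the size of the model's output vector.
std::vector<std::vector<float>> applyMulticlassSketch(ToyDataset& testData, unsigned int nClasses) {
  std::vector<std::vector<float>> result(testData.getNumberOfEvents(), std::vector<float>(nClasses));
  for (unsigned int iEvent = 0; iEvent < testData.getNumberOfEvents(); ++iEvent) {
    testData.loadEvent(iEvent);
    std::vector<float> outputs = toySoftmax(testData.m_input);
    for (unsigned int iClass = 0; iClass < nClasses; ++iClass) {
      result[iEvent][iClass] = outputs[iClass];
    }
  }
  return result;
}
```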
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load event number iEvent.
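The Dataset contract above amounts to two pure virtual functions plus a buffer holding the current event. The sketch below is a hypothetical mirror of that contract, not the actual Belle2::MVA::Dataset declaration. DatasetSketch and InMemoryDataset are illustrative names. Exposing the event through m_input is an assumption taken from the ONNXTensorView description in this listing.

```cpp
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical mirror of the Dataset contract: the current event lives in
// m_input and is replaced by loadEvent().
class DatasetSketch {
public:
  explicit DatasetSketch(std::size_t nFeatures) : m_input(nFeatures, 0.0f) {}
  virtual ~DatasetSketch() = default;
  virtual unsigned int getNumberOfEvents() const = 0;
  virtual void loadEvent(unsigned int iEvent) = 0;
  std::vector<float> m_input;  // features of the currently loaded event
};

// Hypothetical in-memory implementation backed by one row per event.
class InMemoryDataset : public DatasetSketch {
public:
  explicit InMemoryDataset(std::vector<std::vector<float>> rows)
    : DatasetSketch(rows.empty() ? 0 : rows.front().size()), m_rows(std::move(rows)) {}

  unsigned int getNumberOfEvents() const override { return static_cast<unsigned int>(m_rows.size()); }

  void loadEvent(unsigned int iEvent) override {
    if (iEvent >= m_rows.size()) throw std::out_of_range("event index out of range");
    m_input = m_rows[iEvent];
  }

private:
  std::vector<std::vector<float>> m_rows;
};
```

Code written against the base class, such as an expert's apply loop, only ever calls getNumberOfEvents() and loadEvent(), so any storage backend can sit behind it.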
GeneralOptions m_general_options
General options loaded from the weightfile.
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Ort::Env m_env
Environment object for ONNX session.
const char * m_inputNames[1]
Input tensor names.
const char * m_outputNames[1]
Output tensor names.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
void run(ONNXTensorView &view) const
Run the current inputs through the ONNX model. Retrieves and fills the buffers from the view.
View a Dataset's m_input as an ONNX tensor and also set up the output buffers/tensors.
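One way to read "view a Dataset's m_input as an ONNX tensor" is as a non-owning view. In that reading, the input tensor aliases the dataset's buffer, so loading a new event needs no copy as long as the buffer is overwritten in place rather than reallocated. The sketch below assumes that design. TensorViewSketch and ToyDataset are hypothetical, and plain pointers plus a shape vector stand in for Ort::Value tensors. The 64-bit shape type matches how ONNX Runtime describes tensor shapes.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical dataset whose m_input buffer is allocated once and overwritten
// in place, so pointers into it stay valid across events.
struct ToyDataset {
  std::vector<std::vector<float>> events;
  std::vector<float> m_input;

  explicit ToyDataset(std::vector<std::vector<float>> rows)
    : events(std::move(rows)), m_input(events.empty() ? 0 : events.front().size()) {}

  unsigned int getNumberOfEvents() const { return static_cast<unsigned int>(events.size()); }

  void loadEvent(unsigned int iEvent) {
    const std::vector<float>& row = events.at(iEvent);
    if (row.size() != m_input.size()) throw std::invalid_argument("feature count mismatch");
    std::copy(row.begin(), row.end(), m_input.begin());  // no reallocation
  }
};

// Hypothetical non-owning view: exposes the dataset's m_input as a
// {1, nFeatures} tensor and owns a buffer for nOutputs results.
class TensorViewSketch {
public:
  TensorViewSketch(ToyDataset& dataset, std::size_t nOutputs)
    : m_dataset(dataset), m_output(nOutputs, 0.0f) {}

  const float* inputData() const { return m_dataset.m_input.data(); }

  std::vector<std::int64_t> inputShape() const {
    return {1, static_cast<std::int64_t>(m_dataset.m_input.size())};
  }

  float* outputData() { return m_output.data(); }
  std::size_t outputSize() const { return m_output.size(); }

private:
  ToyDataset& m_dataset;       // aliased, not copied
  std::vector<float> m_output; // owned output buffer written by the model
};
```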
The Weightfile class serializes all information about a training into an xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (usually this is a weightfile of an MVA library).
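Together, getFile() and generateFileName() let a model embedded in the weightfile be handed to a library that only loads from disk, as load() does with the "ONNX_Modelfile" entry. Below is a hypothetical sketch of that round trip. It assumes each embedded file is stored as a raw byte string keyed by identifier, whereas the real Weightfile serializes to an xml tree. WeightfileSketch and addFile() are illustrative names, and C++17 is required for std::filesystem.

```cpp
#include <filesystem>
#include <fstream>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical in-memory weightfile: each embedded file is kept as a raw
// byte string under an identifier such as "ONNX_Modelfile".
class WeightfileSketch {
public:
  void addFile(const std::string& identifier, const std::string& path) {
    std::ifstream in(path, std::ios::binary);
    if (!in) throw std::runtime_error("cannot read " + path);
    m_blobs[identifier] = std::string(std::istreambuf_iterator<char>(in), std::istreambuf_iterator<char>());
  }

  // Writes the stored bytes back to disk so a library that loads models
  // from a file path can read them.
  void getFile(const std::string& identifier, const std::string& custom_weightfile) const {
    auto it = m_blobs.find(identifier);
    if (it == m_blobs.end()) throw std::runtime_error("no file stored under " + identifier);
    std::ofstream out(custom_weightfile, std::ios::binary);
    out.write(it->second.data(), static_cast<std::streamsize>(it->second.size()));
    if (!out) throw std::runtime_error("cannot write " + custom_weightfile);
  }

  // Returns a new path in the system temp directory ending in the given suffix.
  std::string generateFileName(const std::string& suffix = "") {
    std::filesystem::path dir = std::filesystem::temp_directory_path();
    std::string name = "weightfile_sketch_" + std::to_string(m_counter++) + suffix;
    return (dir / name).string();
  }

private:
  std::map<std::string, std::string> m_blobs;
  unsigned int m_counter = 0;
};
```

The counter-based name is only unique within one process. A production version would need a collision-safe naming scheme and cleanup of the temporary files.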