11#include <mva/interface/Expert.h>
12#include <mva/interface/Teacher.h>
13#include <mva/interface/Options.h>
15#include <onnxruntime/onnxruntime_cxx_api.h>
32 virtual void load(
const boost::property_tree::ptree&)
override {}
37 virtual void save(boost::property_tree::ptree&)
const override {}
44 return po::options_description(
"ONNX options");
50 virtual std::string
getMethod()
const override {
return "ONNX"; }
93 OrtDeviceAllocator, OrtMemTypeCPU)),
95 m_memoryInfo, dataset.m_input.data(), dataset.m_input.size(),
161 virtual std::vector<float>
apply(
Dataset& testData)
const override;
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Expert()=default
Default constructor.
General options which are shared by all MVA trainings.
Expert for the ONNX MVA method.
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Ort::Env m_env
Environment object for ONNX session.
const char * m_inputNames[1]
Input tensor names.
const char * m_outputNames[1]
Output tensor names.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
void run(ONNXTensorView &view) const
Run the current inputs through the ONNX model. Retrieves and fills the buffers from the view.
Options for the ONNX MVA method.
virtual std::string getMethod() const override
Return method name.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
virtual void load(const boost::property_tree::ptree &) override
Load mechanism to load Options from an XML tree.
virtual void save(boost::property_tree::ptree &) const override
Save mechanism to store Options in an XML tree.
virtual Weightfile train(Dataset &) const override
Just returns a default-initialized weightfile.
ONNXTeacher(const GeneralOptions &general_options, const ONNXOptions &)
Constructs a new teacher using the GeneralOptions and specific options of this training.
View a Dataset's m_input as ONNX Tensor and also set up output buffers/Tensors.
Ort::Value * inputTensor()
Get a pointer to the inputTensor.
std::vector< int64_t > m_inputShape
Shape of the input Tensor.
Ort::Value m_inputTensor
The input Tensor.
std::vector< float > outputData()
Get a vector of output values (which may have been filled).
Ort::Value * outputTensor()
Get a pointer to the outputTensor.
std::vector< int64_t > m_outputShape
Shape of the output Tensor.
ONNXTensorView(Dataset &dataset, int nOutputs)
Construct a new ONNXTensorView.
std::vector< float > m_outputData
Output Tensor buffer.
Ort::MemoryInfo m_memoryInfo
MemoryInfo object to be used when constructing the ONNX Tensors; used to specify the device (CPU).
Ort::Value m_outputTensor
The output Tensor.
Specific Options; all method Options have to inherit from this class.
Teacher(const GeneralOptions &general_options)
Constructs a new teacher using the GeneralOptions for this training.
The Weightfile class serializes all information about a training into an XML tree.
Abstract base class for different kinds of events.