11#include <mva/interface/Expert.h>
12#include <mva/interface/Teacher.h>
13#include <mva/interface/Options.h>
15#include <onnxruntime/onnxruntime_cxx_api.h>
75 for (
auto n : shape) size *= n;
87 if (n < 0)
throw std::invalid_argument(
"All shape dimensions must be positive");
102 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU))
120 Tensor(std::vector<T> values, std::vector<int64_t> shape)
123 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU))
126 throw std::length_error(
127 "Size of the given values vector (" + std::to_string(m_values.size()) +
") "
128 "does not match the product of the shape dimensions (" + std::to_string(sizeFromShape(m_shape)) +
")");
147 return std::make_shared<Tensor>(std::move(shape));
165 std::vector<int64_t> shape)
167 return std::make_shared<Tensor>(std::move(values), std::move(shape));
198 auto&
at(std::vector<size_t> index)
200 size_t flat_index = 0;
202 for (int64_t i =
m_shape.size() - 1; i >= 0; --i) {
203 if (index[i] >=
static_cast<size_t>(
m_shape[i])) {
204 throw std::out_of_range(
205 "index " + std::to_string(index[i]) +
" is out of bounds for axis "
206 + std::to_string(i) +
" with size " + std::to_string(
m_shape[i]));
208 flat_index += index[i] * stride;
211 return at(flat_index);
223 if (
m_values.size() != values.size()) {
224 throw std::length_error(
225 "Size of new values vector (" + std::to_string(values.size()) +
") "
226 "differs from internal size (" + std::to_string(
m_values.size()) +
")");
/**
 * Constructs a new ONNX Runtime Session using the specified model file.
 *
 * @param filename The model file to use for the session
 *                 (presumably a file system path -- confirm against the definition).
 *
 * NOTE(review): `filename` is taken by const value, so the caller's string is
 * copied on every call. Switching to `const std::string&` would avoid the copy,
 * but has to be done together with the definition, which is not visible here.
 */
309 Session(
const std::string filename);
/**
 * Runs inference on the model using named Tensor maps.
 *
 * @param inputMap  Map from tensor name to a shared pointer to the BaseTensor used as input.
 * @param outputMap Map from tensor name to a shared pointer to the BaseTensor used as output
 *                  (presumably the pointed-to tensors receive the results -- confirm
 *                  against the definition).
 */
320 void run(
const std::map<std::string, std::shared_ptr<BaseTensor>>& inputMap,
321 const std::map<std::string, std::shared_ptr<BaseTensor>>& outputMap);
/**
 * Overload of run() that takes tensor names and Ort::Value objects directly
 * instead of named Tensor maps.
 *
 * @param inputNames  Names of the inputs as C strings.
 * @param inputs      Ort::Value objects for the inputs (passed by non-const reference).
 * @param outputNames Names of the outputs as C strings.
 * @param outputs     Ort::Value objects for the outputs (passed by non-const reference).
 *
 * NOTE(review): presumably inputNames/inputs and outputNames/outputs are expected
 * to be index-aligned and of equal length -- confirm against the definition.
 */
336 void run(
const std::vector<const char*>& inputNames,
337 std::vector<Ort::Value>& inputs,
338 const std::vector<const char*>& outputNames,
339 std::vector<Ort::Value>& outputs);
/**
 * Load mechanism to load Options from an XML tree.
 *
 * The parameter is the boost property tree to read the options from; it is
 * unnamed in this declaration.
 */
380 virtual void load(
const boost::property_tree::ptree&)
override;
/**
 * Save mechanism to store Options in an XML tree.
 *
 * The parameter is the boost property tree to write the options into; it is
 * unnamed in this declaration. The method is const, so it does not modify
 * the options object itself.
 */
385 virtual void save(boost::property_tree::ptree&)
const override;
392 return po::options_description(
"ONNX options");
398 virtual std::string
getMethod()
const override {
return "ONNX"; }
/**
 * Apply this expert onto a dataset.
 *
 * @param testData Dataset the expert is applied to.
 * @return Vector of float outputs (presumably one entry per event in the
 *         dataset -- confirm against the definition).
 */
449 virtual std::vector<float>
apply(
Dataset& testData)
const override;
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Expert()=default
Default constructor.
General options which are shared by all MVA trainings.
Expert for the ONNX MVA method.
void configureInputOutputNames()
Set up input and output names and perform consistency checks.
ONNXOptions m_specific_options
ONNX specific options loaded from weightfile.
std::unique_ptr< ONNX::Session > m_session
The ONNX inference session wrapper.
std::string m_outputName
Name of the output tensor (will either be determined automatically or loaded from specific options)
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
std::string m_inputName
Name of the input tensor (will be determined automatically)
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
Options for the ONNX MVA method.
virtual std::string getMethod() const override
Return method name.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
std::string m_outputName
Name of the output Tensor that is used to make predictions.
virtual void load(const boost::property_tree::ptree &) override
Load mechanism to load Options from an XML tree.
virtual void save(boost::property_tree::ptree &) const override
Save mechanism to store Options in an XML tree.
virtual Weightfile train(Dataset &) const override
Just returns a default-initialized weightfile.
ONNXTeacher(const GeneralOptions &general_options, const ONNXOptions &)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Interface class for Tensor template instantiations.
virtual Ort::Value createOrtTensor()=0
Is implemented by Tensor::createOrtTensor.
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Ort::Env m_env
Environment object for ONNX session.
Session(const std::string filename)
Constructs a new ONNX Runtime Session using the specified model file.
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
const Ort::Session & getOrtSession()
Get a reference to the raw Ort::Session object.
void run(const std::map< std::string, std::shared_ptr< BaseTensor > > &inputMap, const std::map< std::string, std::shared_ptr< BaseTensor > > &outputMap)
Runs inference on the model using named Tensor maps.
size_t sizeFromShape(const std::vector< int64_t > &shape)
Calculates the internal vector size from the product of the shape dimensions.
Tensor(std::vector< T > values, std::vector< int64_t > shape)
Constructs a tensor from a data vector and shape.
static auto make_shared(std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from shape.
Tensor(std::vector< int64_t > shape)
Constructs a tensor from shape.
std::vector< int64_t > m_shape
The dimensions of the tensor.
Ort::Value createOrtTensor()
Create an Ort::Value from pointers to the underlying data and shape.
std::vector< T > m_values
Flat buffer storing tensor data in row-major order.
auto & at(std::vector< size_t > index)
Accesses the element at the specified multi-dimensional index.
static auto make_shared(std::vector< T > values, std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from values and shape.
auto & at(size_t index)
Accesses the element at the specified flat index.
Ort::MemoryInfo m_memoryInfo
Memory information used for allocating ONNX Runtime tensors.
void setValues(const std::vector< T > &values)
Replaces the internal values with a new vector.
void checkShapePositive()
Checks if all shape dimensions are positive.
Specific Options, all method Options have to inherit from this class.
Teacher(const GeneralOptions &general_options)
Constructs a new teacher using the GeneralOptions for this training.
The Weightfile class serializes all information about a training into an xml tree.
Abstract base class for different kinds of events.