Belle II Software development
ONNX.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10
11#include <mva/interface/Expert.h>
12#include <mva/interface/Teacher.h>
13#include <mva/interface/Options.h>
14
15#include <onnxruntime/onnxruntime_cxx_api.h>
16
17namespace Belle2 {
22 namespace MVA {
23 namespace ONNX {
24
31 class BaseTensor {
32 public:
33 virtual ~BaseTensor() {}
34
38 virtual Ort::Value createOrtTensor() = 0;
39 };
40
50 template <typename T>
51 class Tensor : public BaseTensor {
55 std::vector<T> m_values;
56
60 std::vector<int64_t> m_shape;
61
67 Ort::MemoryInfo m_memoryInfo;
68
72 size_t sizeFromShape(const std::vector<int64_t>& shape)
73 {
74 size_t size = 1;
75 for (auto n : shape) size *= n;
76 return size;
77 }
78
85 {
86 for (auto n : m_shape) {
87 if (n < 0) throw std::invalid_argument("All shape dimensions must be positive");
88 }
89 }
90
91 public:
99 Tensor(std::vector<int64_t> shape)
100 : m_values(sizeFromShape(shape)), m_shape(std::move(shape)),
102 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU))
103 {
105 }
106
107
120 Tensor(std::vector<T> values, std::vector<int64_t> shape)
121 : m_values(std::move(values)), m_shape(std::move(shape)),
123 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU))
124 {
125 if (sizeFromShape(m_shape) != m_values.size()) {
126 throw std::length_error(
127 "Size of the given values vector (" + std::to_string(m_values.size()) + ") "
128 "does not match the product of the shape dimensions (" + std::to_string(sizeFromShape(m_shape)) + ")");
129 }
131 }
132
145 static auto make_shared(std::vector<int64_t> shape)
146 {
147 return std::make_shared<Tensor>(std::move(shape));
148 }
149
164 static auto make_shared(std::vector<T> values,
165 std::vector<int64_t> shape)
166 {
167 return std::make_shared<Tensor>(std::move(values), std::move(shape));
168 }
169
181 auto& at(size_t index) { return m_values.at(index); }
182
198 auto& at(std::vector<size_t> index)
199 {
200 size_t flat_index = 0;
201 size_t stride = 1;
202 for (int64_t i = m_shape.size() - 1; i >= 0; --i) {
203 if (index[i] >= static_cast<size_t>(m_shape[i])) {
204 throw std::out_of_range(
205 "index " + std::to_string(index[i]) + " is out of bounds for axis "
206 + std::to_string(i) + " with size " + std::to_string(m_shape[i]));
207 }
208 flat_index += index[i] * stride;
209 stride *= m_shape[i];
210 }
211 return at(flat_index);
212 }
213
221 void setValues(const std::vector<T>& values)
222 {
223 if (m_values.size() != values.size()) {
224 throw std::length_error(
225 "Size of new values vector (" + std::to_string(values.size()) + ") "
226 "differs from internal size (" + std::to_string(m_values.size()) + ")");
227 }
228 m_values = values;
229 }
230
249 Ort::Value createOrtTensor()
250 {
251 return Ort::Value::CreateTensor(m_memoryInfo, m_values.data(),
252 m_values.size(), m_shape.data(),
253 m_shape.size());
254 }
255 };
256
257
301 class Session {
302 public:
303
309 Session(const std::string filename);
310
320 void run(const std::map<std::string, std::shared_ptr<BaseTensor>>& inputMap,
321 const std::map<std::string, std::shared_ptr<BaseTensor>>& outputMap);
322
336 void run(const std::vector<const char*>& inputNames,
337 std::vector<Ort::Value>& inputs,
338 const std::vector<const char*>& outputNames,
339 std::vector<Ort::Value>& outputs);
340
346 const Ort::Session& getOrtSession() { return *m_session; }
347
348 private:
352 Ort::Env m_env;
353
357 Ort::SessionOptions m_sessionOptions;
358
362 std::unique_ptr<Ort::Session> m_session;
363
367 Ort::RunOptions m_runOptions;
368 };
369 } // namespace ONNX
370
375
376 public:
380 virtual void load(const boost::property_tree::ptree&) override;
381
385 virtual void save(boost::property_tree::ptree&) const override;
386
390 virtual po::options_description getDescription() override
391 {
392 return po::options_description("ONNX options");
393 }
394
398 virtual std::string getMethod() const override { return "ONNX"; }
399
407 std::string m_outputName;
408 };
409
414 class ONNXTeacher : public Teacher {
415
416 public:
422 ONNXTeacher(const GeneralOptions& general_options,
423 const ONNXOptions&) : Teacher(general_options) {}
424
428 virtual Weightfile train(Dataset&) const override
429 {
430 return Weightfile();
431 }
432 };
433
437 class ONNXExpert : public Expert {
438 public:
443 virtual void load(Weightfile& weightfile) override;
444
449 virtual std::vector<float> apply(Dataset& testData) const override;
450
455 virtual std::vector<std::vector<float>> applyMulticlass(Dataset& test_data) const override;
456
457 private:
462
466 std::unique_ptr<ONNX::Session> m_session;
467
472
476 std::string m_inputName;
477
482 std::string m_outputName;
483 };
484 } // namespace MVA
486} // namespace Belle2
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition Dataset.h:33
Expert()=default
Default constructor.
General options which are shared by all MVA trainings.
Definition Options.h:62
Expert for the ONNX MVA method.
Definition ONNX.h:437
void configureInputOutputNames()
Set up input and output names and perform consistency checks.
Definition ONNX.cc:72
ONNXOptions m_specific_options
ONNX specific options loaded from weightfile.
Definition ONNX.h:471
std::unique_ptr< ONNX::Session > m_session
The ONNX inference session wrapper.
Definition ONNX.h:466
std::string m_outputName
Name of the output tensor (will either be determined automatically or loaded from specific options)
Definition ONNX.h:482
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition ONNX.cc:120
std::string m_inputName
Name of the input tensor (will be determined automatically)
Definition ONNX.h:476
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
Definition ONNX.cc:130
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
Definition ONNX.cc:147
Options for the ONNX MVA method.
Definition ONNX.h:374
virtual std::string getMethod() const override
Return method name.
Definition ONNX.h:398
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition ONNX.h:390
std::string m_outputName
Name of the output Tensor that is used to make predictions.
Definition ONNX.h:407
virtual void load(const boost::property_tree::ptree &) override
Load mechanism to load Options from a xml tree.
Definition ONNX.cc:62
virtual void save(boost::property_tree::ptree &) const override
Save mechanism to store Options in a xml tree.
Definition ONNX.cc:67
virtual Weightfile train(Dataset &) const override
Just returns a default-initialized weightfile.
Definition ONNX.h:428
ONNXTeacher(const GeneralOptions &general_options, const ONNXOptions &)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition ONNX.h:422
Interface class for Tensor template instantiations.
Definition ONNX.h:31
virtual Ort::Value createOrtTensor()=0
Is implemented by Tensor::createOrtTensor.
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Definition ONNX.h:367
Ort::Env m_env
Environment object for ONNX session.
Definition ONNX.h:352
Session(const std::string filename)
Constructs a new ONNX Runtime Session using the specified model file.
Definition ONNX.cc:18
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
Definition ONNX.h:362
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
Definition ONNX.h:357
const Ort::Session & getOrtSession()
Get a reference to the raw Ort::Session object.
Definition ONNX.h:346
void run(const std::map< std::string, std::shared_ptr< BaseTensor > > &inputMap, const std::map< std::string, std::shared_ptr< BaseTensor > > &outputMap)
Runs inference on the model using named Tensor maps.
Definition ONNX.cc:35
size_t sizeFromShape(const std::vector< int64_t > &shape)
Calculates the internal vector size from the product of the shape dimensions.
Definition ONNX.h:72
Tensor(std::vector< T > values, std::vector< int64_t > shape)
Constructs a tensor from a data vector and shape.
Definition ONNX.h:120
static auto make_shared(std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from shape.
Definition ONNX.h:145
Tensor(std::vector< int64_t > shape)
Constructs a tensor from shape.
Definition ONNX.h:99
std::vector< int64_t > m_shape
The dimensions of the tensor.
Definition ONNX.h:60
Ort::Value createOrtTensor()
Create an Ort::Value from pointers to the underlying data and shape.
Definition ONNX.h:249
std::vector< T > m_values
Flat buffer storing tensor data in row-major order.
Definition ONNX.h:55
auto & at(std::vector< size_t > index)
Accesses the element at the specified multi-dimensional index.
Definition ONNX.h:198
static auto make_shared(std::vector< T > values, std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from values and shape.
Definition ONNX.h:164
auto & at(size_t index)
Accesses the element at the specified flat index.
Definition ONNX.h:181
Ort::MemoryInfo m_memoryInfo
Memory information used for allocating ONNX Runtime tensors.
Definition ONNX.h:67
void setValues(const std::vector< T > &values)
Replaces the internal values with a new vector.
Definition ONNX.h:221
void checkShapePositive()
Checks if all shape dimensions are positive.
Definition ONNX.h:84
Specific Options, all method Options have to inherit from this class.
Definition Options.h:98
Teacher(const GeneralOptions &general_options)
Constructs a new teacher using the GeneralOptions for this training.
Definition Teacher.cc:18
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
Abstract base class for different kinds of events.
STL namespace.