Belle II Software development
ONNX.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10
11#include <mva/interface/Expert.h>
12#include <mva/interface/Teacher.h>
13#include <mva/interface/Options.h>
14
15#include <onnxruntime/onnxruntime_cxx_api.h>
16
17namespace Belle2 {
22 namespace MVA {
27
28 public:
32 virtual void load(const boost::property_tree::ptree&) override {}
33
37 virtual void save(boost::property_tree::ptree&) const override {}
38
42 virtual po::options_description getDescription() override
43 {
44 return po::options_description("ONNX options");
45 }
46
50 virtual std::string getMethod() const override { return "ONNX"; }
51 };
52
57 class ONNXTeacher : public Teacher {
58
59 public:
65 ONNXTeacher(const GeneralOptions& general_options,
66 const ONNXOptions&) : Teacher(general_options) {}
67
71 virtual Weightfile train(Dataset&) const override
72 {
73 return Weightfile();
74 }
75 };
76
84 public:
90 ONNXTensorView(Dataset& dataset, int nOutputs)
91 : m_inputShape{1, dataset.getNumberOfFeatures()}, m_outputData(nOutputs),
92 m_outputShape{1, nOutputs}, m_memoryInfo(Ort::MemoryInfo::CreateCpu(
93 OrtDeviceAllocator, OrtMemTypeCPU)),
94 m_inputTensor(Ort::Value::CreateTensor<float>(
95 m_memoryInfo, dataset.m_input.data(), dataset.m_input.size(),
96 m_inputShape.data(), m_inputShape.size())),
97 m_outputTensor(Ort::Value::CreateTensor<float>(
99 m_outputShape.data(), m_outputShape.size())) {}
100
103 Ort::Value* inputTensor() { return &m_inputTensor; }
104
108 Ort::Value* outputTensor() { return &m_outputTensor; }
109
113 std::vector<float> outputData() { return m_outputData; }
114 private:
118 std::vector<int64_t> m_inputShape;
119
123 std::vector<float> m_outputData;
124
128 std::vector<int64_t> m_outputShape;
129
133 Ort::MemoryInfo m_memoryInfo;
134
138 Ort::Value m_inputTensor;
139
143 Ort::Value m_outputTensor;
144 };
145
149 class ONNXExpert : public Expert {
150 public:
155 virtual void load(Weightfile& weightfile) override;
156
161 virtual std::vector<float> apply(Dataset& testData) const override;
162
167 virtual std::vector<std::vector<float>> applyMulticlass(Dataset& test_data) const override;
168
169 private:
174 void run(ONNXTensorView& view) const;
175
179 Ort::Env m_env;
180
184 Ort::SessionOptions m_sessionOptions;
185
189 std::unique_ptr<Ort::Session> m_session;
190
194 Ort::RunOptions m_runOptions;
195
199 const char* m_inputNames[1] = {"input"};
200
204 const char* m_outputNames[1] = {"output"};
205 };
206 } // namespace MVA
208} // namespace Belle2
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition Dataset.h:33
Expert()=default
Default constructor.
General options which are shared by all MVA trainings.
Definition Options.h:62
Expert for the ONNX MVA method.
Definition ONNX.h:149
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Definition ONNX.h:194
Ort::Env m_env
Environment object for ONNX session.
Definition ONNX.h:179
const char * m_inputNames[1]
Input tensor names.
Definition ONNX.h:199
const char * m_outputNames[1]
Output tensor names.
Definition ONNX.h:204
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition ONNX.cc:17
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
Definition ONNX.h:189
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
Definition ONNX.cc:45
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
Definition ONNX.h:184
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
Definition ONNX.cc:58
void run(ONNXTensorView &view) const
Run the current inputs through the ONNX model. Will retrieve and fill the buffers from the view.
Definition ONNX.cc:38
Options for the ONNX MVA method.
Definition ONNX.h:26
virtual std::string getMethod() const override
Return method name.
Definition ONNX.h:50
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition ONNX.h:42
virtual void load(const boost::property_tree::ptree &) override
Load mechanism to load Options from a xml tree.
Definition ONNX.h:32
virtual void save(boost::property_tree::ptree &) const override
Save mechanism to store Options in a xml tree.
Definition ONNX.h:37
virtual Weightfile train(Dataset &) const override
Just returns a default-initialized weightfile.
Definition ONNX.h:71
ONNXTeacher(const GeneralOptions &general_options, const ONNXOptions &)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition ONNX.h:65
View a Dataset's m_input as ONNX Tensor and also set up output buffers/Tensors.
Definition ONNX.h:83
Ort::Value * inputTensor()
Get a pointer to the inputTensor.
Definition ONNX.h:103
std::vector< int64_t > m_inputShape
Shape of the input Tensor.
Definition ONNX.h:118
Ort::Value m_inputTensor
The input Tensor.
Definition ONNX.h:138
std::vector< float > outputData()
Get a vector of output values (that may have been filled)
Definition ONNX.h:113
Ort::Value * outputTensor()
Get a pointer to the outputTensor.
Definition ONNX.h:108
std::vector< int64_t > m_outputShape
Shape of the output Tensor.
Definition ONNX.h:128
ONNXTensorView(Dataset &dataset, int nOutputs)
Construct a new ONNXTensorView.
Definition ONNX.h:90
std::vector< float > m_outputData
Output Tensor buffer.
Definition ONNX.h:123
Ort::MemoryInfo m_memoryInfo
MemoryInfo object to be used when constructing the ONNX Tensors - used to specify device (CPU)
Definition ONNX.h:133
Ort::Value m_outputTensor
The output Tensor.
Definition ONNX.h:143
Specific Options, all method Options have to inherit from this class.
Definition Options.h:98
Teacher(const GeneralOptions &general_options)
Constructs a new teacher using the GeneralOptions for this training.
Definition Teacher.cc:18
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
Abstract base class for different kinds of events.