Belle II Software development
ONNX.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/ONNX.h>
10
11#include <framework/logging/Logger.h>
12#include <iostream>
13#include <vector>
14
15using namespace Belle2::MVA;
16
18{
19 std::string onnxModelFileName = weightfile.generateFileName();
20 weightfile.getFile("ONNX_Modelfile", onnxModelFileName);
21 weightfile.getOptions(m_general_options);
22
23 // Ensure single-threaded execution, see
24 // https://onnxruntime.ai/docs/performance/tune-performance/threading.html
25 //
26 // InterOpNumThreads is probably optional (not used in ORT_SEQUENTIAL mode)
27 // Also, with batch size 1 and ORT_SEQUENTIAL mode, MLP-like models will
28 // always run single threaded, but maybe not e.g. graph networks which can run
29 // in parallel on nodes. Here, setting IntraOpNumThreads to 1 is important to
30 // ensure single-threaded execution.
31 m_sessionOptions.SetIntraOpNumThreads(1);
32 m_sessionOptions.SetInterOpNumThreads(1);
33 m_sessionOptions.SetExecutionMode(ORT_SEQUENTIAL); // default, but make it explicit
34
35 m_session = std::make_unique<Ort::Session>(m_env, onnxModelFileName.c_str(), m_sessionOptions);
36}
37
39{
41 m_inputNames, view.inputTensor(), 1,
42 m_outputNames, view.outputTensor(), 1);
43}
44
45std::vector<float> ONNXExpert::apply(Dataset& testData) const
46{
47 auto view = ONNXTensorView(testData, 1);
48 std::vector<float> result;
49 result.reserve(testData.getNumberOfEvents());
50 for (unsigned int iEvent = 0; iEvent < testData.getNumberOfEvents(); ++iEvent) {
51 testData.loadEvent(iEvent);
52 run(view);
53 result.push_back(view.outputData()[0]);
54 }
55 return result;
56}
57
58std::vector<std::vector<float>> ONNXExpert::applyMulticlass(Dataset& testData) const
59{
60 auto view = ONNXTensorView(testData, m_general_options.m_nClasses);
61 std::vector<std::vector<float>> result(testData.getNumberOfEvents(),
62 std::vector<float>(m_general_options.m_nClasses));
63 for (unsigned int iEvent = 0; iEvent < testData.getNumberOfEvents(); ++iEvent) {
64 testData.loadEvent(iEvent);
65 run(view);
66 auto outputs = view.outputData();
67 for (unsigned int iClass = 0; iClass < m_general_options.m_nClasses; ++iClass) {
68 result[iEvent][iClass] = outputs[iClass];
69 }
70 }
71 return result;
72}
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
General options loaded from the weightfile.
Definition Expert.h:70
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Definition ONNX.h:194
Ort::Env m_env
Environment object for ONNX session.
Definition ONNX.h:179
const char * m_inputNames[1]
Input tensor names.
Definition ONNX.h:199
const char * m_outputNames[1]
Output tensor names.
Definition ONNX.h:204
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition ONNX.cc:17
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
Definition ONNX.h:189
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
Definition ONNX.cc:45
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
Definition ONNX.h:184
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
Definition ONNX.cc:58
void run(ONNXTensorView &view) const
Run the current inputs through the onnx model Will retrieve and fill the buffers from the view.
Definition ONNX.cc:38
View a Dataset's m_input as ONNX Tensor and also set up output buffers/Tensors.
Definition ONNX.h:83
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition Weightfile.cc:67
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)