Belle II Software prerelease-10-00-00a
ONNX.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/ONNX.h>
10
11#include <framework/logging/Logger.h>
12#include <iostream>
13#include <vector>
14
15using namespace Belle2::MVA;
16
17void ONNXExpert::load(Weightfile& weightfile)
18{
19 std::string onnxModelFileName = weightfile.generateFileName();
20 weightfile.getFile("ONNX_Modelfile", onnxModelFileName);
21
22 // Ensure single-threaded execution, see
23 // https://onnxruntime.ai/docs/performance/tune-performance/threading.html
24 //
25 // InterOpNumThreads is probably optional (not used in ORT_SEQUENTIAL mode)
26 // Also, with batch size 1 and ORT_SEQUENTIAL mode, MLP-like models will
27 // always run single threaded, but maybe not e.g. graph networks which can run
28 // in parallel on nodes. Here, setting IntraOpNumThreads to 1 is important to
29 // ensure single-threaded execution.
30 m_sessionOptions.SetIntraOpNumThreads(1);
31 m_sessionOptions.SetInterOpNumThreads(1);
32 m_sessionOptions.SetExecutionMode(ORT_SEQUENTIAL); // default, but make it explicit
33
34 m_session = std::make_unique<Ort::Session>(m_env, onnxModelFileName.c_str(), m_sessionOptions);
35}
36
37void ONNXExpert::run(ONNXTensorView& view) const
38{
39 m_session->Run(m_runOptions,
40 m_inputNames, view.inputTensor(), 1,
41 m_outputNames, view.outputTensor(), 1);
42}
43
44std::vector<float> ONNXExpert::apply(Dataset& testData) const
45{
46 auto view = ONNXTensorView(testData, 1);
47 std::vector<float> result;
48 result.reserve(testData.getNumberOfEvents());
49 for (unsigned int iEvent = 0; iEvent < testData.getNumberOfEvents(); ++iEvent) {
50 testData.loadEvent(iEvent);
51 run(view);
52 result.push_back(view.outputData()[0]);
53 }
54 return result;
55}
56
57std::vector<std::vector<float>> ONNXExpert::applyMulticlass(Dataset& testData) const
58{
59 auto view = ONNXTensorView(testData, m_general_options.m_nClasses);
60 std::vector<std::vector<float>> result(testData.getNumberOfEvents(),
61 std::vector<float>(m_general_options.m_nClasses));
62 for (unsigned int iEvent = 0; iEvent < testData.getNumberOfEvents(); ++iEvent) {
63 testData.loadEvent(iEvent);
64 run(view);
65 auto outputs = view.outputData();
66 for (unsigned int iClass = 0; iClass < m_general_options.m_nClasses; ++iClass) {
67 result[iEvent][iClass] = outputs[iClass];
68 }
69 }
70 return result;
71}
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
General options loaded from the weightfile.
Definition Expert.h:70
Ort::RunOptions m_runOptions
Options to be passed to Ort::Session::Run.
Definition ONNX.h:194
Ort::Env m_env
Environment object for ONNX session.
Definition ONNX.h:179
const char * m_inputNames[1]
Input tensor names.
Definition ONNX.h:199
const char * m_outputNames[1]
Output tensor names.
Definition ONNX.h:204
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition ONNX.cc:17
std::unique_ptr< Ort::Session > m_session
The ONNX inference session.
Definition ONNX.h:189
virtual std::vector< float > apply(Dataset &testData) const override
Apply this expert onto a dataset.
Definition ONNX.cc:44
Ort::SessionOptions m_sessionOptions
ONNX session configuration.
Definition ONNX.h:184
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset and return multiple outputs.
Definition ONNX.cc:57
void run(ONNXTensorView &view) const
Run the current inputs through the ONNX model. Will retrieve and fill the buffers from the view.
Definition ONNX.cc:37
View a Dataset's m_input as an ONNX Tensor and also set up output buffers/Tensors.
Definition ONNX.h:83
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
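For orientation, the onnxruntime calls used above (SessionOptions, Session construction, and Session::Run with named input/output tensors) follow the standard onnxruntime C++ API. The standalone sketch below, which is not part of ONNX.cc, shows the same pattern with hand-built tensors; the model path "model.onnx", the [1, 4] input shape, and the tensor names "input" and "output" are placeholder assumptions for illustration only.

#include <onnxruntime_cxx_api.h>
#include <array>
#include <cstdint>
#include <iostream>

int main()
{
  // Single-threaded session setup, mirroring ONNXExpert::load
  Ort::Env env;
  Ort::SessionOptions options;
  options.SetIntraOpNumThreads(1);
  options.SetInterOpNumThreads(1);
  options.SetExecutionMode(ORT_SEQUENTIAL);
  Ort::Session session(env, "model.onnx", options); // placeholder model path

  // One event with 4 input features and a single output value (assumed shapes)
  std::array<float, 4> input{0.1f, 0.2f, 0.3f, 0.4f};
  std::array<float, 1> output{};
  std::array<std::int64_t, 2> inputShape{1, 4};
  std::array<std::int64_t, 2> outputShape{1, 1};

  // Wrap the existing buffers as ONNX tensors (no copies are made)
  auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
    memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size());
  Ort::Value outputTensor = Ort::Value::CreateTensor<float>(
    memoryInfo, output.data(), output.size(), outputShape.data(), outputShape.size());

  // Tensor names are model-specific; "input"/"output" are placeholders here
  const char* inputNames[] = {"input"};
  const char* outputNames[] = {"output"};
  session.Run(Ort::RunOptions{nullptr},
              inputNames, &inputTensor, 1,
              outputNames, &outputTensor, 1);

  std::cout << "model output: " << output[0] << std::endl;
  return 0;
}

This mirrors how ONNXExpert keeps a single ONNXTensorView over the Dataset buffers and simply reloads the event data (Dataset::loadEvent) before each Run call, instead of rebuilding tensors per event.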