Belle II Software development
ONNXExpert Class Reference

Expert for the ONNX MVA method.

#include <ONNX.h>

Inheritance diagram for ONNXExpert: ONNXExpert derives from Expert.

Public Member Functions

virtual void load (Weightfile &weightfile) override
 Load the expert from a Weightfile.
 
virtual std::vector< float > apply (Dataset &testData) const override
 Apply this expert onto a dataset.
 
virtual std::vector< std::vector< float > > applyMulticlass (Dataset &test_data) const override
 Apply this expert onto a dataset and return multiple outputs.
 

Protected Attributes

GeneralOptions m_general_options
 General options loaded from the weightfile.
 

Private Member Functions

void configureInputOutputNames ()
 Set up input and output names and perform consistency checks.
 

Private Attributes

std::unique_ptr< ONNX::Session > m_session
 The ONNX inference session wrapper.
 
ONNXOptions m_specific_options
 ONNX specific options loaded from weightfile.
 
std::string m_inputName
 Name of the input tensor (will be determined automatically)
 
std::string m_outputName
 Name of the output tensor (will either be determined automatically or loaded from specific options)
 

Detailed Description

Expert for the ONNX MVA method.

Definition at line 437 of file ONNX.h.

Member Function Documentation

◆ apply()

std::vector< float > apply ( Dataset & testData) const
override virtual

Apply this expert onto a dataset.

Parameters
testData    dataset

Implements Expert.

Definition at line 130 of file ONNX.cc.

{
  auto nFeatures = testData.getNumberOfFeatures();
  auto nEvents = testData.getNumberOfEvents();
  auto input = Tensor<float>::make_shared({1, nFeatures});
  auto output = Tensor<float>::make_shared({1, 1});
  std::vector<float> result;
  result.reserve(nEvents);
  for (unsigned int iEvent = 0; iEvent < nEvents; ++iEvent) {
    testData.loadEvent(iEvent);
    input->setValues(testData.m_input);
    m_session->run({{m_inputName, input}}, {{m_outputName, output}});
    result.push_back(output->at(0));
  }
  return result;
}
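Below is a minimal usage sketch for apply(): it loads a trained weightfile, wraps one set of feature values in a dataset and evaluates the expert on it. Only load() and apply() are taken from this page; the header paths, Weightfile::loadFromFile and MVA::SingleDataset are assumptions about the surrounding MVA interface.

#include <mva/methods/ONNX.h>          // assumed location of ONNXExpert
#include <mva/interface/Weightfile.h>  // assumed location of Weightfile
#include <mva/interface/Dataset.h>     // assumed location of SingleDataset

#include <iostream>
#include <vector>

using namespace Belle2;

int main()
{
  // Assumed helper: read a previously trained ONNX weightfile from disk
  auto weightfile = MVA::Weightfile::loadFromFile("onnx_classifier.xml");

  // The general options (including the list of feature names) are stored in the weightfile
  MVA::GeneralOptions general_options;
  weightfile.getOptions(general_options);

  // Assumed helper: a dataset holding a single event with dummy feature values
  std::vector<float> features(general_options.m_variables.size(), 0.f);
  MVA::SingleDataset dataset(general_options, features);

  // The two calls documented on this page
  MVA::ONNXExpert expert;
  expert.load(weightfile);
  std::vector<float> response = expert.apply(dataset);

  std::cout << "classifier output: " << response[0] << std::endl;
  return 0;
}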

◆ applyMulticlass()

std::vector< std::vector< float > > applyMulticlass ( Dataset & test_data) const
override virtual

Apply this expert onto a dataset and return multiple outputs.

Parameters
test_data    dataset

Reimplemented from Expert.

Definition at line 147 of file ONNX.cc.

{
  unsigned int nClasses = m_general_options.m_nClasses;
  auto nFeatures = testData.getNumberOfFeatures();
  auto nEvents = testData.getNumberOfEvents();
  auto input = Tensor<float>::make_shared({1, nFeatures});
  auto output = Tensor<float>::make_shared({1, nClasses});
  std::vector<std::vector<float>> result(nEvents, std::vector<float>(nClasses));
  for (unsigned int iEvent = 0; iEvent < nEvents; ++iEvent) {
    testData.loadEvent(iEvent);
    input->setValues(testData.m_input);
    m_session->run({{m_inputName, input}}, {{m_outputName, output}});
    for (unsigned int iClass = 0; iClass < nClasses; ++iClass) {
      result[iEvent][iClass] = output->at(iClass);
    }
  }
  return result;
}
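Continuing the sketch after apply(), the return value of applyMulticlass() is indexed first by event and then by class, with the number of classes taken from GeneralOptions::m_nClasses stored in the weightfile:

// result[iEvent][iClass] holds the response for class iClass of event iEvent
auto result = expert.applyMulticlass(dataset);
for (unsigned int iEvent = 0; iEvent < result.size(); ++iEvent) {
  for (unsigned int iClass = 0; iClass < result[iEvent].size(); ++iClass) {
    std::cout << "event " << iEvent << ", class " << iClass
              << ": " << result[iEvent][iClass] << "\n";
  }
}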

◆ configureInputOutputNames()

void configureInputOutputNames ( )
private

Set up input and output names and perform consistency checks.

Definition at line 72 of file ONNX.cc.

{
  const auto& inputNames = m_session->getOrtSession().GetInputNames();
  const auto& outputNames = m_session->getOrtSession().GetOutputNames();

  // Check if we have a single input model and set the input name to that
  if (inputNames.size() != 1) {
    std::stringstream msg;
    msg << "Model has multiple inputs: ";
    for (auto name : inputNames)
      msg << "\"" << name << "\" ";
    msg << "- only single-input models are supported.";
    B2FATAL(msg.str());
  }
  m_inputName = inputNames[0];

  m_outputName = m_specific_options.m_outputName;

  // For single-output models we just take the name of that single output
  if (outputNames.size() == 1) {
    if (!m_outputName.empty() && m_outputName != outputNames[0]) {
      B2INFO("Output name of the model is "
             << outputNames[0]
             << " - will use that despite the configured name being \""
             << m_outputName << "\"");
    }
    m_outputName = outputNames[0];
    return;
  }

  // Otherwise we have a multiple-output model and need to check if the
  // configured output name, or the fallback value "output", exists
  if (m_outputName.empty()) {
    m_outputName = "output";
  }
  auto outputFound = std::find(outputNames.begin(), outputNames.end(),
                               m_outputName) != outputNames.end();
  if (!outputFound) {
    std::stringstream msg;
    msg << "No output named \"" << m_outputName << "\" found. Instead got ";
    for (auto name : outputNames)
      msg << "\"" << name << "\" ";
    msg << "- either change your model to contain one named \"" << m_outputName
        << "\" or set `m_outputName` in the specific options to one of the available names.";
    B2FATAL(msg.str());
  }
}
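In practice this means a multi-output model needs either an output literally named "output" or an explicit name in the specific options. A minimal sketch, assuming m_outputName is a public member of ONNXOptions as usual for the MVA option structs; the output name "probabilities" is purely hypothetical:

// Select a specific output of a multi-output ONNX model by name.
// If m_outputName is left empty, configureInputOutputNames() falls back to "output".
MVA::ONNXOptions specific_options;
specific_options.m_outputName = "probabilities";  // hypothetical output name of the model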

◆ load()

void load ( Weightfile & weightfile)
override virtual

Load the expert from a Weightfile.

Parameters
weightfile    containing all information necessary to build the expert

Implements Expert.

Definition at line 120 of file ONNX.cc.

{
  std::string onnxModelFileName = weightfile.generateFileName();
  weightfile.getFile("ONNX_Modelfile", onnxModelFileName);
  weightfile.getOptions(m_general_options);
  weightfile.getOptions(m_specific_options);
  m_session = std::make_unique<Session>(onnxModelFileName.c_str());
  configureInputOutputNames();
}
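load() therefore expects the weightfile to contain the serialized model under the identifier "ONNX_Modelfile" together with the general and ONNX-specific options. A rough sketch of how such a weightfile could be assembled on the training side; Weightfile::addFile and Weightfile::addOptions are assumed counterparts of the getFile()/getOptions() calls above and are not documented on this page:

MVA::GeneralOptions general_options;  // feature names, number of classes, ...
MVA::ONNXOptions specific_options;    // e.g. m_outputName for multi-output models

MVA::Weightfile weightfile;
weightfile.addOptions(general_options);              // assumed counterpart of getOptions()
weightfile.addOptions(specific_options);
weightfile.addFile("ONNX_Modelfile", "model.onnx");  // assumed counterpart of getFile()

// ONNXExpert::load() retrieves exactly these entries and opens the model
MVA::ONNXExpert expert;
expert.load(weightfile);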

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
protected inherited

General options loaded from the weightfile.

Definition at line 70 of file Expert.h.

◆ m_inputName

std::string m_inputName
private

Name of the input tensor (will be determined automatically)

Definition at line 476 of file ONNX.h.

◆ m_outputName

std::string m_outputName
private

Name of the output tensor (will either be determined automatically or loaded from specific options)

Definition at line 482 of file ONNX.h.

◆ m_session

std::unique_ptr<ONNX::Session> m_session
private

The ONNX inference session wrapper.

Definition at line 466 of file ONNX.h.

◆ m_specific_options

ONNXOptions m_specific_options
private

ONNX specific options loaded from weightfile.

Definition at line 471 of file ONNX.h.


The documentation for this class was generated from the following files:

ONNX.h
ONNX.cc