Belle II Software light-2406-ragdoll
TMVAExpertClassification Class Reference

Expert for the TMVA Classification MVA method.

#include <TMVA.h>


Public Member Functions

virtual void load (Weightfile &weightfile) override
 Load the expert from a Weightfile.
 
virtual std::vector< float > apply (Dataset &test_data) const override
 Apply this expert to a dataset.
 
virtual std::vector< std::vector< float > > applyMulticlass (Dataset &test_data) const
 Apply this expert to a dataset (multiclass mode).
 

Protected Attributes

TMVAOptionsClassification specific_options
 Method specific options.
 
float expert_signalFraction
 Signal fraction used to calculate the probability.
 
std::unique_ptr< TMVA::Reader > m_expert
 TMVA::Reader pointer.
 
std::vector< float > m_input_cache
 Input cache for TMVA::Reader; otherwise the branch addresses would have to be set in each apply call.
 
std::vector< float > m_spectators_cache
 Spectators cache for TMVA::Reader; otherwise the branch addresses would have to be set in each apply call.
 
GeneralOptions m_general_options
 General options loaded from the weightfile.
 

Detailed Description

Expert for the TMVA Classification MVA method.

Definition at line 304 of file TMVA.h.

Member Function Documentation

◆ apply()

std::vector< float > apply (Dataset &test_data) const
override, virtual

Apply this expert to a dataset.

Parameters
test_data  dataset

Implements Expert.

Definition at line 507 of file TMVA.cc.

{

  std::vector<float> probabilities(test_data.getNumberOfEvents());
  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
    test_data.loadEvent(iEvent);
    // Copy the current event into the caches registered with the TMVA::Reader
    for (unsigned int i = 0; i < m_input_cache.size(); ++i)
      m_input_cache[i] = test_data.m_input[i];
    for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
      m_spectators_cache[i] = test_data.m_spectators[i];
    // Either convert the method output to a probability (using the signal fraction) or return the raw MVA output
    if (specific_options.transform2probability)
      probabilities[iEvent] = m_expert->GetProba(specific_options.m_method, expert_signalFraction);
    else
      probabilities[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
  }
  return probabilities;

}

◆ applyMulticlass()

virtual std::vector< std::vector< float > > applyMulticlass (Dataset &test_data) const
inline, virtual, inherited

Apply this expert to a dataset.

Multiclass mode. This function is not pure virtual, since not all derived classes need to re-implement it.

Parameters
test_data  dataset.
Returns
vector of size N = test_data.getNumberOfEvents(), where each entry holds K = m_classes.size() scores for the corresponding event in the dataset.

Reimplemented in PythonExpert, TMVAExpertMulticlass, and TrivialExpert.

Definition at line 56 of file Expert.h.

{

  B2ERROR("Attempted to call applyMulticlass() of the abstract base class MVA::Expert. All methods that support multiclass classification should override this definition.");
  (void) test_data;

  return std::vector<std::vector<float>>();
};

◆ load()

void load (Weightfile &weightfile)
override, virtual

Load the expert from a Weightfile.

Parameters
weightfile  containing all information necessary to build the expert

Reimplemented from TMVAExpert.

Definition at line 411 of file TMVA.cc.

{

  weightfile.getOptions(specific_options);
  if (specific_options.transform2probability) {
    expert_signalFraction = weightfile.getSignalFraction();
  }

  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
  weightfile.getFile("TMVA_Weightfile", custom_weightfile);

  TMVAExpert::load(weightfile);

  if (specific_options.m_type == "Plugins") {
    // Register plugin handlers so that TMVA can construct TMVA::Method<m_method> when booking
    auto base = std::string("TMVA@@MethodBase");
    auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
    auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
    auto className = std::string("TMVA::Method") + specific_options.m_method;
    auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
    auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
    auto pluginName = std::string("TMVA") + specific_options.m_method;

    gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
    gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
    B2INFO("Registered new TMVA Plugin named " << pluginName);
  }

  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
    B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
  }

}

Member Data Documentation

◆ expert_signalFraction

float expert_signalFraction
protected

Signal fraction used to calculate the probability.

Definition at line 321 of file TMVA.h.

◆ m_expert

std::unique_ptr<TMVA::Reader> m_expert
protected, inherited

TMVA::Reader pointer.

Definition at line 294 of file TMVA.h.

◆ m_general_options

GeneralOptions m_general_options
protected, inherited

General options loaded from the weightfile.

Definition at line 70 of file Expert.h.

◆ m_input_cache

std::vector<float> m_input_cache
mutable, protected, inherited

Input cache for TMVA::Reader; otherwise the branch addresses would have to be set in each apply call.

Definition at line 296 of file TMVA.h.

◆ m_spectators_cache

std::vector<float> m_spectators_cache
mutable, protected, inherited

Spectators cache for TMVA::Reader; otherwise the branch addresses would have to be set in each apply call.

Definition at line 298 of file TMVA.h.

◆ specific_options

TMVAOptionsClassification specific_options
protected

Method specific options.

Definition at line 320 of file TMVA.h.

