Belle II Software development
TMVA.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10#ifndef INCLUDE_GUARD_BELLE2_MVA_TMVA_HEADER
11#define INCLUDE_GUARD_BELLE2_MVA_TMVA_HEADER
12
13#include <mva/interface/Options.h>
14#include <mva/interface/Teacher.h>
15#include <mva/interface/Expert.h>
16
17#include <TMVA/Factory.h>
18#include <TMVA/Tools.h>
19#include <TMVA/Reader.h>
20#include <TMVA/DataLoader.h>
21
22#include <memory>
23
24namespace Belle2 {
29 namespace MVA {
30
35
36 public:
41 virtual void load(const boost::property_tree::ptree& pt) override;
42
47 virtual void save(boost::property_tree::ptree& pt) const override;
48
52 virtual po::options_description getDescription() override;
53
57 virtual std::string getMethod() const override { return "TMVA"; }
58
59 public:
60 std::string m_method = "BDT";
61 std::string m_type = "BDT";
66 std::string m_config =
67 "!H:!V:CreateMVAPdfs:NTrees=400:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=1024:MaxDepth=3:IgnoreNegWeightsInTraining";
68 //std::string method = "FastBDT";
69 //std::string type = "Plugins";
70 //std::string config = "!H:!V:CreateMVAPdfs:NTrees=400:Shrinkage=0.10:RandRatio=0.5:NCutLevel=8:NTreeLayers=3";
71 std::string m_factoryOption = "!V:!Silent:Color:DrawProgressBar";
72 std::string m_prepareOption = "SplitMode=random:!V";
73 std::string m_workingDirectory = "";
74 std::string m_prefix = "TMVA";
75 };
76
81
82 public:
88 {
89 m_factoryOption += ":AnalysisType=Classification";
90 }
91
96 virtual void load(const boost::property_tree::ptree& pt) override;
97
102 virtual void save(boost::property_tree::ptree& pt) const override;
103
107 virtual po::options_description getDescription() override;
108
112 virtual std::string getMethod() const override { return "TMVAClassification"; }
113
114 public:
116 };
117
118
123
124 public:
130 {
131 m_factoryOption += ":AnalysisType=Multiclass";
132 }
133
138 virtual void load(const boost::property_tree::ptree& pt) override;
139
144 virtual void save(boost::property_tree::ptree& pt) const override;
145
149 virtual po::options_description getDescription() override;
150
154 virtual std::string getMethod() const override { return "TMVAMulticlass"; }
155
156 public:
157
158 std::vector<std::string> m_classes;
160 };
161
162
167
168 public:
174 {
175 m_factoryOption += ":AnalysisType=Regression";
176 }
177
181 virtual std::string getMethod() const override { return "TMVARegression"; }
182 };
183
184
188 class TMVATeacher : public Teacher {
189
190 public:
196 TMVATeacher(const GeneralOptions& general_options, const TMVAOptions& _specific_options);
197
204 Weightfile trainFactory(TMVA::Factory& factory, TMVA::DataLoader& data_loader, const std::string& jobName) const;
205
206 private:
209 };
210
215
216 public:
222 TMVATeacherClassification(const GeneralOptions& general_options, const TMVAOptionsClassification& _specific_options);
223
228 virtual Weightfile train(Dataset& training_data) const override;
229
230 protected:
232 };
233
238
239 public:
245 TMVATeacherMulticlass(const GeneralOptions& general_options, const TMVAOptionsMulticlass& _specific_options);
246
251 virtual Weightfile train(Dataset& training_data) const override;
252
253 protected:
255 };
256
261
262 public:
268 TMVATeacherRegression(const GeneralOptions& general_options, const TMVAOptionsRegression& _specific_options);
269
274 virtual Weightfile train(Dataset& training_data) const override;
275
276 protected:
278 };
279
280
284 class TMVAExpert : public MVA::Expert {
285
286 public:
291 virtual void load(Weightfile& weightfile) override;
292
293 protected:
294 std::unique_ptr<TMVA::Reader> m_expert;
295 mutable std::vector<float>
297 mutable std::vector<float>
299 };
300
305
306 public:
311 virtual void load(Weightfile& weightfile) override;
312
317 virtual std::vector<float> apply(Dataset& test_data) const override;
318
319 protected:
323 };
324
329
330 public:
335 virtual void load(Weightfile& weightfile) override;
336
341 virtual std::vector<float> apply(Dataset& test_data) const override
342 {
343 (void) test_data;
344 return std::vector<float>();
345 };
346
352 virtual std::vector<std::vector<float>> applyMulticlass(Dataset& test_data) const override;
353
354 protected:
357 };
358
363
364 public:
369 virtual void load(Weightfile& weightfile) override;
370
375 virtual std::vector<float> apply(Dataset& test_data) const override;
376
377 protected:
380 };
381
382 }
384}
385#endif
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
Abstract base class of all Experts. Each MVA library has its own implementation of this class,...
Definition: Expert.h:31
General options which are shared by all MVA trainings.
Definition: Options.h:62
Specific options; all method options have to inherit from this class.
Definition: Options.h:98
Expert for the TMVA Classification MVA method.
Definition: TMVA.h:304
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:320
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert to a dataset.
Definition: TMVA.cc:507
float expert_signalFraction
Signal fraction used to calculate the probability.
Definition: TMVA.h:321
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:411
Expert for the TMVA Multiclass MVA method.
Definition: TMVA.h:328
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert to a dataset.
Definition: TMVA.h:341
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:355
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:445
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert to a dataset.
Definition: TMVA.cc:526
Expert for the TMVA Regression MVA method.
Definition: TMVA.h:362
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert to a dataset.
Definition: TMVA.cc:542
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:378
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:476
Expert for the TMVA MVA method.
Definition: TMVA.h:284
std::vector< float > m_input_cache
Input cache for the TMVA::Reader; without it, the branch addresses would have to be set in every apply call.
Definition: TMVA.h:296
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
Definition: TMVA.h:294
std::vector< float > m_spectators_cache
Spectator cache for the TMVA::Reader; without it, the branch addresses would have to be set in every apply ...
Definition: TMVA.h:298
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:385
Options for the TMVA Classification MVA method.
Definition: TMVA.h:80
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:112
TMVAOptionsClassification()
Constructor. Adds Classification as the AnalysisType to the factory options.
Definition: TMVA.h:87
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:81
bool transform2probability
Transform the output of the method into a probability.
Definition: TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: TMVA.cc:75
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:154
TMVAOptionsMulticlass()
Constructor. Adds Multiclass as the AnalysisType to the factory options.
Definition: TMVA.h:129
std::vector< std::string > m_classes
Class name identifiers.
Definition: TMVA.h:158
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:110
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: TMVA.cc:89
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: TMVA.cc:100
Options for the TMVA Regression MVA method.
Definition: TMVA.h:166
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:181
TMVAOptionsRegression()
Constructor. Adds Regression as the AnalysisType to the factory options.
Definition: TMVA.h:173
Options for the TMVA MVA method.
Definition: TMVA.h:34
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:57
std::string m_prepareOption
Prepare options passed to the prepareTrainingAndTestTree method.
Definition: TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:66
std::string m_method
TMVA method name
Definition: TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:55
std::string m_factoryOption
Factory options passed to the TMVA factory.
Definition: TMVA.h:71
std::string m_type
TMVA method type
Definition: TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA; if empty, a temporary directory is used.
Definition: TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: TMVA.cc:27
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: TMVA.cc:43
Teacher for the TMVA Classification MVA method.
Definition: TMVA.h:214
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:231
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method on the given dataset and return a Weightfile.
Definition: TMVA.cc:207
Teacher for the TMVA Multiclass MVA method.
Definition: TMVA.h:237
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:254
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method on the given dataset and return a Weightfile.
Definition: TMVA.cc:299
Teacher for the TMVA Regression MVA method.
Definition: TMVA.h:260
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:277
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method on the given dataset and return a Weightfile.
Definition: TMVA.cc:310
Teacher for the TMVA MVA method.
Definition: TMVA.h:188
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train an MVA method using the given data loader and return a Weightfile.
Definition: TMVA.cc:122
TMVAOptions specific_options
Method specific options.
Definition: TMVA.h:207
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
The Weightfile class serializes all information about a training into an XML tree.
Definition: Weightfile.h:38
Abstract base class for different kinds of events.