Belle II Software  release-08-01-10
TMVA.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #pragma once
10 #ifndef INCLUDE_GUARD_BELLE2_MVA_TMVA_HEADER
11 #define INCLUDE_GUARD_BELLE2_MVA_TMVA_HEADER
12 
13 #include <mva/interface/Options.h>
14 #include <mva/interface/Teacher.h>
15 #include <mva/interface/Expert.h>
16 
17 #include <TMVA/Factory.h>
18 #include <TMVA/Tools.h>
19 #include <TMVA/Reader.h>
20 #include <TMVA/DataLoader.h>
21 
22 #include <memory>
23 
24 namespace Belle2 {
29  namespace MVA {
30 
34  class TMVAOptions : public SpecificOptions {
35 
36  public:
41  virtual void load(const boost::property_tree::ptree& pt) override;
42 
47  virtual void save(boost::property_tree::ptree& pt) const override;
48 
52  virtual po::options_description getDescription() override;
53 
57  virtual std::string getMethod() const override { return "TMVA"; }
58 
59  public:
60  std::string m_method = "BDT";
61  std::string m_type = "BDT";
66  std::string m_config =
67  "!H:!V:CreateMVAPdfs:NTrees=400:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=1024:MaxDepth=3:IgnoreNegWeightsInTraining";
68  //std::string method = "FastBDT";
69  //std::string type = "Plugins";
70  //std::string config = "!H:!V:CreateMVAPdfs:NTrees=400:Shrinkage=0.10:RandRatio=0.5:NCutLevel=8:NTreeLayers=3";
71  std::string m_factoryOption = "!V:!Silent:Color:DrawProgressBar";
72  std::string m_prepareOption = "SplitMode=random:!V";
73  std::string m_workingDirectory = "";
74  std::string m_prefix = "TMVA";
75  };
76 
81 
82  public:
88  {
89  m_factoryOption += ":AnalysisType=Classification";
90  }
91 
96  virtual void load(const boost::property_tree::ptree& pt) override;
97 
102  virtual void save(boost::property_tree::ptree& pt) const override;
103 
107  virtual po::options_description getDescription() override;
108 
112  virtual std::string getMethod() const override { return "TMVAClassification"; }
113 
114  public:
115  bool transform2probability = true;
116  };
117 
118 
123 
124  public:
130  {
131  m_factoryOption += ":AnalysisType=Multiclass";
132  }
133 
138  virtual void load(const boost::property_tree::ptree& pt) override;
139 
144  virtual void save(boost::property_tree::ptree& pt) const override;
145 
149  virtual po::options_description getDescription() override;
150 
154  virtual std::string getMethod() const override { return "TMVAMulticlass"; }
155 
156  public:
157 
158  std::vector<std::string> m_classes;
160  };
161 
162 
167 
168  public:
174  {
175  m_factoryOption += ":AnalysisType=Regression";
176  }
177 
181  virtual std::string getMethod() const override { return "TMVARegression"; }
182  };
183 
184 
188  class TMVATeacher : public Teacher {
189 
190  public:
196  TMVATeacher(const GeneralOptions& general_options, const TMVAOptions& _specific_options);
197 
204  Weightfile trainFactory(TMVA::Factory& factory, TMVA::DataLoader& data_loader, const std::string& jobName) const;
205 
206  private:
209  };
210 
215 
216  public:
222  TMVATeacherClassification(const GeneralOptions& general_options, const TMVAOptionsClassification& _specific_options);
223 
228  virtual Weightfile train(Dataset& training_data) const override;
229 
230  protected:
232  };
233 
238 
239  public:
245  TMVATeacherMulticlass(const GeneralOptions& general_options, const TMVAOptionsMulticlass& _specific_options);
246 
251  virtual Weightfile train(Dataset& training_data) const override;
252 
253  protected:
255  };
256 
261 
262  public:
268  TMVATeacherRegression(const GeneralOptions& general_options, const TMVAOptionsRegression& _specific_options);
269 
274  virtual Weightfile train(Dataset& training_data) const override;
275 
276  protected:
278  };
279 
280 
284  class TMVAExpert : public MVA::Expert {
285 
286  public:
291  virtual void load(Weightfile& weightfile) override;
292 
293  protected:
294  std::unique_ptr<TMVA::Reader> m_expert;
295  mutable std::vector<float>
297  mutable std::vector<float>
299  };
300 
305 
306  public:
311  virtual void load(Weightfile& weightfile) override;
312 
317  virtual std::vector<float> apply(Dataset& test_data) const override;
318 
319  protected:
323  };
324 
329 
330  public:
335  virtual void load(Weightfile& weightfile) override;
336 
341  virtual std::vector<float> apply(Dataset& test_data) const override
342  {
343  (void) test_data;
344  return std::vector<float>();
345  };
346 
352  virtual std::vector<std::vector<float>> applyMulticlass(Dataset& test_data) const override;
353 
354  protected:
357  };
358 
363 
364  public:
369  virtual void load(Weightfile& weightfile) override;
370 
375  virtual std::vector<float> apply(Dataset& test_data) const override;
376 
377  protected:
380  };
381 
382  }
384 }
385 #endif
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
Abstract base class of all Experts. Each MVA library has its own implementation of this class,...
Definition: Expert.h:31
General options which are shared by all MVA trainings.
Definition: Options.h:62
Specific Options, all method Options have to inherit from this class.
Definition: Options.h:98
Expert for the TMVA Classification MVA method.
Definition: TMVA.h:304
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:320
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:507
float expert_signalFraction
Signal fraction used to calculate the probability.
Definition: TMVA.h:321
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:411
Expert for the TMVA Multiclass MVA method.
Definition: TMVA.h:328
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.h:341
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:355
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:445
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:526
Expert for the TMVA Regression MVA method.
Definition: TMVA.h:362
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:542
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:378
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:476
Expert for the TMVA MVA method.
Definition: TMVA.h:284
std::vector< float > m_input_cache
Input cache for TMVA::Reader; otherwise we would have to set the branch addresses in each apply call.
Definition: TMVA.h:296
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
Definition: TMVA.h:294
std::vector< float > m_spectators_cache
Spectators cache for TMVA::Reader; otherwise we would have to set the branch addresses in each apply ...
Definition: TMVA.h:298
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:385
Options for the TMVA Classification MVA method.
Definition: TMVA.h:80
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:112
TMVAOptionsClassification()
Constructor. Adds Classification as AnalysisType to the factoryOptions.
Definition: TMVA.h:87
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:81
bool transform2probability
Transform output of method to a probability.
Definition: TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: TMVA.cc:75
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:154
TMVAOptionsMulticlass()
Constructor. Adds Multiclass as AnalysisType to the factoryOptions.
Definition: TMVA.h:129
std::vector< std::string > m_classes
Class name identifiers.
Definition: TMVA.h:158
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:110
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: TMVA.cc:89
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: TMVA.cc:100
Options for the TMVA Regression MVA method.
Definition: TMVA.h:166
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:181
TMVAOptionsRegression()
Constructor. Adds Regression as AnalysisType to the factoryOptions.
Definition: TMVA.h:173
Options for the TMVA MVA method.
Definition: TMVA.h:34
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:57
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition: TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:66
std::string m_method
TMVA method name
Definition: TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:55
std::string m_factoryOption
Factory options passed to the TMVA factory.
Definition: TMVA.h:71
std::string m_type
TMVA method type
Definition: TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition: TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: TMVA.cc:27
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: TMVA.cc:43
Teacher for the TMVA Classification MVA method.
Definition: TMVA.h:214
TMVATeacherClassification(const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:203
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:231
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:207
Teacher for the TMVA Multiclass MVA method.
Definition: TMVA.h:237
TMVATeacherMulticlass(const GeneralOptions &general_options, const TMVAOptionsMulticlass &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:294
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:254
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:299
Teacher for the TMVA Regression MVA method.
Definition: TMVA.h:260
TMVATeacherRegression(const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:306
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:277
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:310
Teacher for the TMVA MVA method.
Definition: TMVA.h:188
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:119
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train a mva method using the given data loader returning a Weightfile.
Definition: TMVA.cc:122
TMVAOptions specific_options
Method specific options.
Definition: TMVA.h:207
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
The Weightfile class serializes all information about a training into an XML tree.
Definition: Weightfile.h:38
Abstract base class for different kinds of events.