Belle II Software light-2406-ragdoll
TMVATeacherRegression Class Reference

Teacher for the TMVA Regression MVA method.

#include <TMVA.h>

Inheritance and collaboration diagrams for TMVATeacherRegression: images not reproduced (base classes: TMVATeacher, which derives from Teacher).

Public Member Functions

 TMVATeacherRegression (const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
 Constructs a new teacher using the GeneralOptions and specific options of this training.
 
virtual Weightfile train (Dataset &training_data) const override
 Train an MVA method using the given dataset, returning a Weightfile.
 
Weightfile trainFactory (TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
 Train an MVA method using the given data loader, returning a Weightfile.
 

Protected Attributes

TMVAOptionsRegression specific_options
 Method specific options.
 
GeneralOptions m_general_options
 GeneralOptions containing all shared options.
 

Detailed Description

Teacher for the TMVA Regression MVA method.

Definition at line 260 of file TMVA.h.

Constructor & Destructor Documentation

◆ TMVATeacherRegression()

TMVATeacherRegression (const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)

Constructs a new teacher using the GeneralOptions and specific options of this training.

Parameters
general_options: GeneralOptions defining all shared options
_specific_options: TMVAOptionsRegression defining all method specific options

Definition at line 306 of file TMVA.cc.

  : TMVATeacher(general_options, _specific_options),
    specific_options(_specific_options) { }

Referenced symbols:
TMVAOptionsRegression specific_options: Method specific options. Definition: TMVA.h:277
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options): Constructs a new teacher using the GeneralOptions and specific options of this training. Definition: TMVA.cc:119
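The constructor hands `_specific_options` to the `TMVATeacher` base, whose `specific_options` member is a `TMVAOptions` (TMVA.h:207), and also stores its own `TMVAOptionsRegression` copy under the same name (TMVA.h:277). The base copy is therefore sliced, and inside `TMVATeacherRegression` the derived member hides the base one: `train()` reads the derived copy, while the inherited `trainFactory()` reads the base copy. The sketch below reproduces this C++ pattern with hypothetical stand-in types (`BaseOptions`, `RegressionOptions`, `BaseTeacher`, `RegressionTeacher`, and the invented field `m_extra`); it illustrates the language mechanics only and is not Belle II code.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for TMVAOptions.
struct BaseOptions {
  std::string m_method = "BDT";
};

// Hypothetical stand-in for TMVAOptionsRegression; m_extra is invented.
struct RegressionOptions : BaseOptions {
  std::string m_extra = "regression-only";
};

class BaseTeacher {
public:
  // Copies only the BaseOptions part of whatever is passed in (slicing).
  explicit BaseTeacher(const BaseOptions& options) : specific_options(options) {}
  // Base-class code always sees the sliced copy.
  std::string baseMethod() const { return specific_options.m_method; }
protected:
  BaseOptions specific_options;
};

class RegressionTeacher : public BaseTeacher {
public:
  explicit RegressionTeacher(const RegressionOptions& options)
    : BaseTeacher(options), specific_options(options) {}
  // Inside the derived class the unqualified name finds the derived copy.
  std::string derivedMethod() const { return specific_options.m_method; }
  std::string extra() const { return specific_options.m_extra; }
  // The hidden base copy is still reachable with explicit qualification.
  std::string hiddenBaseMethod() const { return BaseTeacher::specific_options.m_method; }
protected:
  RegressionOptions specific_options;  // hides BaseTeacher::specific_options
};
```

Because both copies are initialized from the same argument, they agree on every field they share; only the derived copy can carry fields that exist solely in the derived options type.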

Member Function Documentation

◆ train()

virtual Weightfile train (Dataset &training_data) const override

Train an MVA method using the given dataset, returning a Weightfile.

Parameters
training_data: Dataset used to train the method

Implements Teacher.

Definition at line 310 of file TMVA.cc.

{

  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
  unsigned int numberOfEvents = training_data.getNumberOfEvents();

  std::string directory = specific_options.m_workingDirectory;
  if (specific_options.m_workingDirectory.empty()) {
    char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2TMVA.XXXXXX").c_str());
    directory = mkdtemp(directory_template);
    free(directory_template);
  }

  // cppcheck-suppress unreadVariable
  auto guard = ScopeGuard::guardWorkingDirectory(directory);

  std::string jobName = specific_options.m_prefix;
  if (jobName.empty())
    jobName = "TMVA";
  TFile classFile((jobName + ".root").c_str(), "RECREATE");
  classFile.cd();

  TMVA::Tools::Instance();
  TMVA::DataLoader data_loader(jobName);
  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);

  // Add variables to the data loader
  for (auto& var : m_general_options.m_variables) {
    data_loader.AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
  }

  // Add spectators to the data loader
  for (auto& var : m_general_options.m_spectators) {
    data_loader.AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
  }

  auto* regression_tree = new TTree("regression_tree", "regression_tree");

  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
    regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
                            &training_data.m_input[iFeature]);
  }
  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
    regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
                            &training_data.m_spectators[iSpectator]);
  }
  regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_target_variable).c_str(),
                          &training_data.m_target);

  regression_tree->Branch("__weight__", &training_data.m_weight);

  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
    training_data.loadEvent(iEvent);
    regression_tree->Fill();
  }

  data_loader.AddRegressionTree(regression_tree);
  data_loader.SetWeightExpression(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_weight_variable), "Regression");

  auto weightfile = trainFactory(factory, data_loader, jobName);
  weightfile.addOptions(specific_options);

  delete regression_tree;

  if (specific_options.m_workingDirectory.empty()) {
    std::filesystem::remove_all(directory);
  }

  return weightfile;

}
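The first and last steps of `train()` manage the working directory: if `m_workingDirectory` is empty, a unique directory is created from a `Basf2TMVA.XXXXXX` template with `mkdtemp`, the training runs under `ScopeGuard::guardWorkingDirectory(directory)`, and the directory is removed at the end. The sketch below shows that pattern in isolation using the standard library and POSIX `mkdtemp`. `makeTempDirectory`, `WorkingDirectoryGuard` and `runInWorkingDirectory` are hypothetical names, and the guard assumes (based on how the listing calls it) that the one-argument `guardWorkingDirectory` switches into the given directory and restores the previous one on scope exit. The template is copied into a `std::vector<char>` instead of using `strdup`/`free`, since `mkdtemp` needs a writable buffer either way.

```cpp
#include <cassert>
#include <cstdlib>     // mkdtemp (POSIX, declared in <stdlib.h>)
#include <filesystem>
#include <stdexcept>
#include <string>
#include <system_error>
#include <vector>

namespace fs = std::filesystem;

// Create a unique directory "<tmp>/<prefix>.XXXXXX"; mkdtemp edits the buffer in place.
std::string makeTempDirectory(const std::string& prefix) {
  const std::string pattern = (fs::temp_directory_path() / (prefix + ".XXXXXX")).string();
  std::vector<char> buffer(pattern.begin(), pattern.end());
  buffer.push_back('\0');
  if (mkdtemp(buffer.data()) == nullptr)
    throw std::runtime_error("mkdtemp failed for " + pattern);
  return std::string(buffer.data());
}

// Hypothetical RAII guard: enter `target`, restore the previous directory on destruction.
class WorkingDirectoryGuard {
public:
  explicit WorkingDirectoryGuard(const fs::path& target) : m_previous(fs::current_path()) {
    fs::current_path(target);
  }
  ~WorkingDirectoryGuard() {
    std::error_code ec;  // never throw from a destructor
    fs::current_path(m_previous, ec);
  }
  WorkingDirectoryGuard(const WorkingDirectoryGuard&) = delete;
  WorkingDirectoryGuard& operator=(const WorkingDirectoryGuard&) = delete;
private:
  fs::path m_previous;
};

// Run `work` inside `workingDirectory`, or inside a fresh temporary directory when it is
// empty; the temporary directory is removed after `work` returns normally.
template <typename Work>
std::string runInWorkingDirectory(const std::string& workingDirectory, Work work) {
  const bool temporary = workingDirectory.empty();
  const std::string directory = temporary ? makeTempDirectory("Basf2TMVA") : workingDirectory;
  {
    WorkingDirectoryGuard guard(directory);
    work();
  }  // previous working directory restored here
  if (temporary)
    fs::remove_all(directory);
  return directory;
}
```

The sketch restores the previous working directory before deleting the temporary one, so the process never keeps a deleted directory as its working directory.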

Referenced symbols:
std::vector< std::string > m_variables: Vector of all variables (branch names) used in the training. Definition: Options.h:86
std::string m_weight_variable: Weight variable (branch name) defining the weights. Definition: Options.h:91
std::vector< std::string > m_spectators: Vector of all spectators (branch names) used in the training. Definition: Options.h:87
std::string m_target_variable: Target variable (branch name) defining the target. Definition: Options.h:90
std::string m_prefix: Prefix used for all files generated by TMVA. Definition: TMVA.h:74
std::string m_factoryOption: Factory options passed to the TMVA factory. Definition: TMVA.h:71
std::string m_workingDirectory: Working directory of TMVA; if empty, a temporary directory is used. Definition: TMVA.h:73
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const: Train an MVA method using the given data loader, returning a Weightfile. Definition: TMVA.cc:122
GeneralOptions m_general_options: GeneralOptions containing all shared options. Definition: Teacher.h:49
static std::string makeROOTCompatible(std::string str): Remove special characters that ROOT dislikes in branch names.
static ScopeGuard guardWorkingDirectory(): Create a ScopeGuard of the current working directory. Definition: ScopeGuard.h:296
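The tree-filling loop in `train()` relies on address binding: each `Branch` call records the address of a member of `training_data` (`m_input[iFeature]`, `m_spectators[iSpectator]`, `m_target`, `m_weight`), `loadEvent(iEvent)` overwrites those members in place, and `Fill()` copies whatever values currently sit at the bound addresses. The toy types below (`ToyTree`, `ToyDataset`, `copyIntoTree`) are hypothetical and are not ROOT's `TTree`; they reproduce only that mechanism, with spectators omitted for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical column store: each branch remembers an address, Fill() snapshots the values.
class ToyTree {
public:
  void Branch(const std::string& name, float* address) {
    m_names.push_back(name);
    m_addresses.push_back(address);
  }
  void Fill() {
    std::vector<float> row;
    for (const float* address : m_addresses)
      row.push_back(*address);  // read the current value behind each bound address
    m_rows.push_back(std::move(row));
  }
  const std::vector<std::vector<float>>& rows() const { return m_rows; }
private:
  std::vector<std::string> m_names;
  std::vector<float*> m_addresses;
  std::vector<std::vector<float>> m_rows;
};

// Hypothetical dataset: loadEvent() overwrites the public buffers in place.
struct ToyDataset {
  std::vector<std::vector<float>> events;  // each event: features..., target, weight
  std::vector<float> m_input;
  float m_target = 0;
  float m_weight = 1;

  ToyDataset(std::vector<std::vector<float>> rows, std::size_t nFeatures)
    : events(std::move(rows)), m_input(nFeatures) {}

  void loadEvent(std::size_t i) {
    const auto& e = events[i];
    for (std::size_t f = 0; f < m_input.size(); ++f)
      m_input[f] = e[f];
    m_target = e[m_input.size()];
    m_weight = e[m_input.size() + 1];
  }
};

// Bind every branch once, then load and fill each event in turn.
ToyTree copyIntoTree(ToyDataset& data) {
  ToyTree tree;
  for (std::size_t f = 0; f < data.m_input.size(); ++f)
    tree.Branch("feature" + std::to_string(f), &data.m_input[f]);
  tree.Branch("target", &data.m_target);
  tree.Branch("__weight__", &data.m_weight);
  for (std::size_t i = 0; i < data.events.size(); ++i) {
    data.loadEvent(i);
    tree.Fill();
  }
  return tree;
}
```

In this pattern the bound addresses stay valid only as long as the underlying vectors are not reallocated, so the buffers must already have their final size when `Branch` is called.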

◆ trainFactory()

Weightfile trainFactory (TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
inherited

Train an MVA method using the given data loader, returning a Weightfile.

Parameters
factory: TMVA::Factory used to train the method
data_loader: TMVA::DataLoader used to train the method
jobName: name of the TMVA training
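Besides booking the method, `trainFactory()` derives several names purely from `specific_options.m_method` and `jobName`: the two plugin-handler registrations used when `m_type` is `"Plugins"`, and the path of the TMVA weight file added to the Weightfile as `TMVA_Weightfile`. The hypothetical helpers below (`PluginStrings`, `pluginStringsFor`, `weightFilePath`) repeat exactly those string concatenations so the results can be inspected for an illustrative method name.

```cpp
#include <cassert>
#include <string>

// Hypothetical bundle of the strings trainFactory() passes to ROOT's plugin manager.
struct PluginStrings {
  std::string base;
  std::string regexp1;
  std::string regexp2;
  std::string className;
  std::string ctor1;
  std::string ctor2;
  std::string pluginName;
};

// Same concatenations as in trainFactory(), for a given m_method value.
PluginStrings pluginStringsFor(const std::string& method) {
  PluginStrings s;
  s.base = "TMVA@@MethodBase";
  s.regexp1 = ".*_" + method + ".*";
  s.regexp2 = ".*" + method + ".*";
  s.className = "TMVA::Method" + method;
  s.ctor1 = "Method" + method + "(TMVA::DataSetInfo&,TString)";
  s.ctor2 = "Method" + method + "(TString&,TString&,TMVA::DataSetInfo&,TString&)";
  s.pluginName = "TMVA" + method;
  return s;
}

// Path of the TMVA weight file that trainFactory() registers as "TMVA_Weightfile".
std::string weightFilePath(const std::string& jobName, const std::string& method) {
  return "TMVA/weights/" + jobName + "_" + method + ".weights.xml";
}
```

For example, with the default job name `TMVA` (used by `train()` when `m_prefix` is empty) and a method named `BDT`, `weightFilePath` yields `TMVA/weights/TMVA_BDT.weights.xml`.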

Definition at line 122 of file TMVA.cc.

{
  data_loader.PrepareTrainingAndTestTree("", specific_options.m_prepareOption);

  if (specific_options.m_type == "Plugins") {
    auto base = std::string("TMVA@@MethodBase");
    auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
    auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
    auto className = std::string("TMVA::Method") + specific_options.m_method;
    auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
    auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
    auto pluginName = std::string("TMVA") + specific_options.m_method;

    gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
    gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
  }

  if (!factory.BookMethod(&data_loader, specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
    B2ERROR("TMVA Method with name " + specific_options.m_method + " cannot be booked.");
  }

  Weightfile weightfile;
  std::string logfilename = weightfile.generateFileName(".log");

  // Pipe stdout into a logfile to get TMVA output, which contains valuable information
  // which cannot be retrieved otherwise!
  // Hence we do some black magic here
  // TODO Using ROOT_VERSION 6.08 this should be possible without this workaround
  auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
  auto saved_stdout = dup(STDOUT_FILENO);
  dup2(logfile, 1);

  factory.TrainAllMethods();
  factory.TestAllMethods();
  factory.EvaluateAllMethods();

  // Reset original output
  dup2(saved_stdout, STDOUT_FILENO);
  close(saved_stdout);
  close(logfile);


  weightfile.addOptions(m_general_options);
  weightfile.addFile("TMVA_Weightfile", std::string("TMVA/weights/") + jobName + "_" + specific_options.m_method + ".weights.xml");
  weightfile.addFile("TMVA_Logfile", logfilename);

  // We have to parse the TMVA output to get the feature importances, there is no other way currently
  std::string begin = "Ranking input variables (method specific)";
  std::string end = "-----------------------------------";
  std::string line;
  std::ifstream file(logfilename, std::ios::in);
  std::map<std::string, float> feature_importances;
  int state = 0;
  while (std::getline(file, line)) {
    if (state == 0 && line.find(begin) != std::string::npos) {
      state = 1;
      continue;
    }
    if (state >= 1 and state <= 4) {
      state++;
      continue;
    }
    if (state == 5) {
      if (line.find(end) != std::string::npos)
        break;
      std::vector<std::string> strs;
      boost::split(strs, line, boost::is_any_of(":"));
      std::string variable = strs[2];
      boost::trim(variable);
      variable = Belle2::MakeROOTCompatible::invertMakeROOTCompatible(variable);
      float importance = std::stof(strs[3]);
      feature_importances[variable] = importance;
    }
  }
  weightfile.addFeatureImportance(feature_importances);

  return weightfile;

}
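The logfile workaround in `trainFactory()` redirects file descriptor 1 with `dup`/`dup2`, runs the training, and restores the saved descriptor. One detail the listing does not handle explicitly is buffering: text still sitting in a `stdout` or `std::cout` buffer when the descriptor is switched is written to whichever file is current at flush time. The sketch below (`captureStdout` and `readFile` are hypothetical names) adds explicit flushes around both switches and restores stdout even if the wrapped work throws; whether TMVA's own output actually needs the extra flushes is an assumption not checked here.

```cpp
#include <cassert>
#include <cstdio>
#include <fcntl.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unistd.h>

// Run `work` with file descriptor 1 (stdout) redirected into `logfilename`.
template <typename Work>
void captureStdout(const std::string& logfilename, Work work) {
  const int logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
  if (logfile < 0)
    throw std::runtime_error("cannot open " + logfilename);

  // Flush pending output so it still goes to the original stdout.
  std::cout.flush();
  std::fflush(stdout);
  const int saved_stdout = dup(STDOUT_FILENO);
  if (saved_stdout < 0) {
    close(logfile);
    throw std::runtime_error("dup failed");
  }
  dup2(logfile, STDOUT_FILENO);

  auto restore = [&] {
    // Flush captured output into the logfile before switching back.
    std::cout.flush();
    std::fflush(stdout);
    dup2(saved_stdout, STDOUT_FILENO);
    close(saved_stdout);
    close(logfile);
  };

  try {
    work();
  } catch (...) {
    restore();
    throw;
  }
  restore();
}

// Read a whole file into a string, e.g. to inspect what was captured.
std::string readFile(const std::string& path) {
  std::ifstream in(path);
  std::stringstream content;
  content << in.rdbuf();
  return content.str();
}
```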
Referenced symbols:
std::string m_prepareOption: Prepare options passed to the PrepareTrainingAndTestTree method. Definition: TMVA.h:72
std::string m_config: TMVA config string for the chosen method. Definition: TMVA.h:66
std::string m_method: TMVA method name. Definition: TMVA.h:60
std::string m_type: TMVA method type. Definition: TMVA.h:61
TMVAOptions specific_options: Method specific options. Definition: TMVA.h:207
static std::string invertMakeROOTCompatible(std::string str): Invert the makeROOTCompatible operation.
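The feature-importance parser in `trainFactory()` is a small state machine over the captured log: it waits for the line containing `Ranking input variables (method specific)`, skips the next four lines, then splits each following line on `:` and takes field 2 as the variable name and field 3 as its importance, until a line of dashes ends the table. The sketch below reimplements that loop with the standard library instead of `boost::split`/`boost::trim`. `parseRanking`, `splitOn` and `trim` are hypothetical names; the sketch omits the `invertMakeROOTCompatible` step and, unlike the listing, skips rows with fewer than four fields rather than indexing past the end. The row layout it expects is inferred from the parsing code, not copied from real TMVA output.

```cpp
#include <cassert>
#include <istream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Strip leading and trailing blanks.
std::string trim(const std::string& s) {
  const auto first = s.find_first_not_of(" \t\r");
  if (first == std::string::npos)
    return "";
  const auto last = s.find_last_not_of(" \t\r");
  return s.substr(first, last - first + 1);
}

// Split on a single separator character.
std::vector<std::string> splitOn(const std::string& s, char sep) {
  std::vector<std::string> parts;
  std::string part;
  std::istringstream stream(s);
  while (std::getline(stream, part, sep))
    parts.push_back(part);
  return parts;
}

// State 0: look for the header; states 1-4: skip four lines; state 5: read rows.
std::map<std::string, float> parseRanking(std::istream& log) {
  const std::string begin = "Ranking input variables (method specific)";
  const std::string end = "-----------------------------------";
  std::map<std::string, float> importances;
  std::string line;
  int state = 0;
  while (std::getline(log, line)) {
    if (state == 0) {
      if (line.find(begin) != std::string::npos)
        state = 1;
      continue;
    }
    if (state <= 4) {
      ++state;
      continue;
    }
    if (line.find(end) != std::string::npos)
      break;
    const auto fields = splitOn(line, ':');
    if (fields.size() < 4)
      continue;  // safeguard added in this sketch
    importances[trim(fields[2])] = std::stof(fields[3]);
  }
  return importances;
}
```

A table row such as `   :    1 : px   : 2.5e-01` is split into four fields, so `parseRanking` records `px` with importance 0.25.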

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
protected, inherited

GeneralOptions containing all shared options.

Definition at line 49 of file Teacher.h.

◆ specific_options

TMVAOptionsRegression specific_options
protected

Method specific options.

Definition at line 277 of file TMVA.h.


The documentation for this class was generated from the following files:
TMVA.h
TMVA.cc