Belle II Software light-2406-ragdoll
TMVATeacherClassification Class Reference

Teacher for the TMVA Classification MVA method.

#include <TMVA.h>


Public Member Functions

 TMVATeacherClassification (const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
 Constructs a new teacher using the GeneralOptions and specific options of this training.
 
virtual Weightfile train (Dataset &training_data) const override
 Train an MVA method on the given dataset and return a Weightfile.
 
Weightfile trainFactory (TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
 Train an MVA method using the given data loader and return a Weightfile.
 

Protected Attributes

TMVAOptionsClassification specific_options
 Method-specific options.
 
GeneralOptions m_general_options
 GeneralOptions containing all shared options.
 

Detailed Description

Teacher for the TMVA Classification MVA method.

Definition at line 214 of file TMVA.h.

Constructor & Destructor Documentation

◆ TMVATeacherClassification()

TMVATeacherClassification (const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)

Constructs a new teacher using the GeneralOptions and specific options of this training.

Parameters
general_options: GeneralOptions defining all shared options
_specific_options: TMVAOptionsClassification defining all method-specific options

Definition at line 203 of file TMVA.cc.

  : TMVATeacher(general_options, _specific_options),
    specific_options(_specific_options) { }
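The constructor hands `_specific_options` both to the `TMVATeacher` base, whose own `TMVAOptions specific_options` member (TMVA.h:207) keeps only the `TMVAOptions` part, and to this class's `TMVAOptionsClassification specific_options` member (TMVA.h:231), which hides the base member. Assuming, as the base-constructor call suggests, that `TMVAOptionsClassification` derives from `TMVAOptions`, the pattern can be sketched as below. All types and the field `m_extra_flag` are hypothetical stand-ins, not the Belle II classes.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for TMVAOptions.
struct BaseOptions {
  std::string m_method = "BDT";
};

// Hypothetical stand-in for TMVAOptionsClassification.
struct ClassificationOptions : BaseOptions {
  bool m_extra_flag = false;  // hypothetical classification-only field
};

// Stand-in for TMVATeacher: keeps its own copy of the BaseOptions part only.
class BaseTeacher {
public:
  explicit BaseTeacher(const BaseOptions& options) : specific_options(options) {}
  const BaseOptions& baseOptions() const { return specific_options; }

protected:
  BaseOptions specific_options;  // sliced copy: derived fields are dropped
};

// Stand-in for TMVATeacherClassification: additionally keeps the full derived copy.
class ClassificationTeacher : public BaseTeacher {
public:
  explicit ClassificationTeacher(const ClassificationOptions& options)
    : BaseTeacher(options), specific_options(options) {}
  const ClassificationOptions& derivedOptions() const { return specific_options; }

protected:
  ClassificationOptions specific_options;  // hides BaseTeacher::specific_options
};
```

Member functions of the base, such as the inherited `trainFactory()`, only see the sliced `TMVAOptions` copy, while `train()` in the derived class resolves `specific_options` to the full classification options.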

Member Function Documentation

◆ train()

Weightfile train (Dataset &training_data) const
override virtual

Train an MVA method on the given dataset and return a Weightfile.

Parameters
training_data: Dataset used to train the method

Implements Teacher.

Definition at line 207 of file TMVA.cc.

{

  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
  unsigned int numberOfEvents = training_data.getNumberOfEvents();

  std::string directory = specific_options.m_workingDirectory;
  if (specific_options.m_workingDirectory.empty()) {
    char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2TMVA.XXXXXX").c_str());
    directory = mkdtemp(directory_template);
    free(directory_template);
  }

  // cppcheck-suppress unreadVariable
  auto guard = ScopeGuard::guardWorkingDirectory(directory);

  std::string jobName = specific_options.m_prefix;
  if (jobName.empty())
    jobName = "TMVA";
  TFile classFile((jobName + ".root").c_str(), "RECREATE");
  classFile.cd();

  TMVA::Tools::Instance();
  TMVA::DataLoader data_loader(jobName);
  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);

  // Add variables to the data loader
  for (auto& var : m_general_options.m_variables) {
    data_loader.AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
  }

  // Add spectators to the data loader
  for (auto& var : m_general_options.m_spectators) {
    data_loader.AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
  }

  auto* signal_tree = new TTree("signal_tree", "signal_tree");
  auto* background_tree = new TTree("background_tree", "background_tree");

  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
    signal_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
                        &training_data.m_input[iFeature]);
    background_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
                            &training_data.m_input[iFeature]);
  }

  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
    signal_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
                        &training_data.m_spectators[iSpectator]);
    background_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
                            &training_data.m_spectators[iSpectator]);
  }

  signal_tree->Branch("__weight__", &training_data.m_weight);
  background_tree->Branch("__weight__", &training_data.m_weight);

  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
    training_data.loadEvent(iEvent);
    if (training_data.m_isSignal) {
      signal_tree->Fill();
    } else {
      background_tree->Fill();
    }
  }

  data_loader.AddSignalTree(signal_tree);
  data_loader.AddBackgroundTree(background_tree);
  auto weightfile = trainFactory(factory, data_loader, jobName);

  weightfile.addOptions(specific_options);
  weightfile.addSignalFraction(training_data.getSignalFraction());

  delete signal_tree;
  delete background_tree;

  if (specific_options.m_workingDirectory.empty()) {
    std::filesystem::remove_all(directory);
  }

  return weightfile;

}
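When `m_workingDirectory` is empty, `train()` creates a uniquely named directory from the template `Basf2TMVA.XXXXXX` under `std::filesystem::temp_directory_path()`, switches into it through `ScopeGuard::guardWorkingDirectory`, and deletes it before returning. A minimal sketch of the same pattern in C++17 plus POSIX `mkdtemp` follows; `WorkingDirectoryGuard`, `makeTemporaryDirectory` and `runInWorkingDirectory` are hypothetical helpers, and the guard is not the Belle II `ScopeGuard` implementation.

```cpp
#include <cassert>
#include <filesystem>
#include <stdexcept>
#include <stdlib.h>  // mkdtemp (POSIX)
#include <string>
#include <system_error>
#include <vector>

namespace fs = std::filesystem;

// Hypothetical RAII helper: enter `target` on construction, return to the
// previous working directory on destruction.
class WorkingDirectoryGuard {
public:
  explicit WorkingDirectoryGuard(const fs::path& target) : m_previous(fs::current_path())
  {
    fs::current_path(target);
  }
  ~WorkingDirectoryGuard()
  {
    std::error_code ec;  // never throw from a destructor
    fs::current_path(m_previous, ec);
  }
  WorkingDirectoryGuard(const WorkingDirectoryGuard&) = delete;
  WorkingDirectoryGuard& operator=(const WorkingDirectoryGuard&) = delete;

private:
  fs::path m_previous;
};

// Create a new, uniquely named directory such as <tmp>/Basf2TMVA.a1B2c3.
fs::path makeTemporaryDirectory(const std::string& prefix)
{
  const std::string pattern = (fs::temp_directory_path() / (prefix + ".XXXXXX")).string();
  std::vector<char> buffer(pattern.begin(), pattern.end());
  buffer.push_back('\0');  // mkdtemp rewrites the trailing XXXXXX in place
  if (mkdtemp(buffer.data()) == nullptr)
    throw std::runtime_error("mkdtemp failed for " + pattern);
  return fs::path(buffer.data());
}

// Run `work` inside `workingDirectory`, or inside a throw-away temporary
// directory when `workingDirectory` is empty (the m_workingDirectory convention).
template <typename Work>
void runInWorkingDirectory(const std::string& workingDirectory, Work work)
{
  const bool temporary = workingDirectory.empty();
  const fs::path directory = temporary ? makeTemporaryDirectory("Basf2TMVA") : fs::path(workingDirectory);
  {
    WorkingDirectoryGuard guard(directory);
    work();
  }  // previous working directory restored here
  if (temporary)
    fs::remove_all(directory);
}
```

The sketch restores the previous working directory before calling `remove_all`, so the process never sits in a deleted directory. It makes no attempt to clean up the temporary directory if `work` throws.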
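The tree-filling loop in `train()` relies on address binding: each branch is created once with a pointer into `training_data` (`m_input[iFeature]`, `m_spectators[iSpectator]`, `m_weight`), `loadEvent(iEvent)` overwrites those members, and each `Fill()` copies the values currently stored there into `signal_tree` or `background_tree` depending on `m_isSignal`. The sketch below models that mechanism without ROOT. `MiniTree` and `MiniDataset` are hypothetical stand-ins for `TTree` and the Belle II `Dataset`, and spectators are left out for brevity.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for TTree: remembers branch addresses and copies
// the values found there each time Fill() is called.
class MiniTree {
public:
  void Branch(const std::string& name, const float* address)
  {
    m_names.push_back(name);
    m_addresses.push_back(address);
  }
  void Fill()
  {
    std::vector<float> row;
    for (const float* address : m_addresses)
      row.push_back(*address);  // read the bound memory at fill time
    m_rows.push_back(std::move(row));
  }
  std::size_t GetEntries() const { return m_rows.size(); }
  const std::vector<float>& row(std::size_t i) const { return m_rows.at(i); }

private:
  std::vector<std::string> m_names;
  std::vector<const float*> m_addresses;
  std::vector<std::vector<float>> m_rows;
};

// Hypothetical stand-in for the Dataset: all events plus one reusable
// "current event" buffer (m_input, m_weight, m_isSignal).
struct MiniDataset {
  std::vector<std::vector<float>> events;  // events[iEvent][iFeature]
  std::vector<bool> labels;
  std::vector<float> weights;
  std::vector<float> m_input;  // sized once, so its element addresses stay valid
  float m_weight = 1.0f;
  bool m_isSignal = false;

  MiniDataset(std::vector<std::vector<float>> e, std::vector<bool> l, std::vector<float> w)
    : events(std::move(e)), labels(std::move(l)), weights(std::move(w)),
      m_input(events.empty() ? 0 : events.front().size()) {}

  void loadEvent(std::size_t iEvent)
  {
    assert(events[iEvent].size() == m_input.size());
    std::copy(events[iEvent].begin(), events[iEvent].end(), m_input.begin());  // overwrite in place
    m_weight = weights[iEvent];
    m_isSignal = labels[iEvent];
  }
};

// Same shape as the loop in train(): bind branches once, then load and Fill per event.
void splitBySignal(MiniDataset& data, MiniTree& signal, MiniTree& background)
{
  for (std::size_t i = 0; i < data.m_input.size(); ++i) {
    signal.Branch("f" + std::to_string(i), &data.m_input[i]);
    background.Branch("f" + std::to_string(i), &data.m_input[i]);
  }
  signal.Branch("__weight__", &data.m_weight);
  background.Branch("__weight__", &data.m_weight);

  for (std::size_t iEvent = 0; iEvent < data.events.size(); ++iEvent) {
    data.loadEvent(iEvent);
    (data.m_isSignal ? signal : background).Fill();
  }
}
```

Because the trees store addresses rather than values, the pattern only works if the event buffer is updated in place; reallocating it after the branches are created would leave them pointing at stale memory.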

◆ trainFactory()

Weightfile trainFactory (TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
inherited

Train an MVA method using the given data loader and return a Weightfile.

Parameters
factory: TMVA::Factory used to train the method
data_loader: TMVA::DataLoader used to train the method
jobName: name of the TMVA training

Definition at line 122 of file TMVA.cc.

{
  data_loader.PrepareTrainingAndTestTree("", specific_options.m_prepareOption);

  if (specific_options.m_type == "Plugins") {
    auto base = std::string("TMVA@@MethodBase");
    auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
    auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
    auto className = std::string("TMVA::Method") + specific_options.m_method;
    auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
    auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
    auto pluginName = std::string("TMVA") + specific_options.m_method;

    gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
    gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
  }

  if (!factory.BookMethod(&data_loader, specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
    B2ERROR("TMVA Method with name " + specific_options.m_method + " cannot be booked.");
  }

  Weightfile weightfile;
  std::string logfilename = weightfile.generateFileName(".log");

  // Pipe stdout into a logfile to get TMVA output, which contains valuable information
  // which cannot be retrieved otherwise!
  // Hence we do some black magic here
  // TODO Using ROOT_VERSION 6.08 this should be possible without this workaround
  auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
  // Flush pending output so that it is not written into the logfile
  std::cout.flush();
  fflush(stdout);
  auto saved_stdout = dup(STDOUT_FILENO);
  dup2(logfile, STDOUT_FILENO);

  factory.TrainAllMethods();
  factory.TestAllMethods();
  factory.EvaluateAllMethods();

  // Flush the TMVA output into the logfile, then reset original output
  std::cout.flush();
  fflush(stdout);
  dup2(saved_stdout, STDOUT_FILENO);
  close(saved_stdout);
  close(logfile);

  weightfile.addOptions(m_general_options);
  weightfile.addFile("TMVA_Weightfile", std::string("TMVA/weights/") + jobName + "_" + specific_options.m_method + ".weights.xml");
  weightfile.addFile("TMVA_Logfile", logfilename);

  // We have to parse the TMVA output to get the feature importances, there is no other way currently
  std::string begin = "Ranking input variables (method specific)";
  std::string end = "-----------------------------------";
  std::string line;
  std::ifstream file(logfilename, std::ios::in);
  std::map<std::string, float> feature_importances;
  int state = 0;
  while (std::getline(file, line)) {
    if (state == 0 && line.find(begin) != std::string::npos) {
      state = 1;
      continue;
    }
    if (state >= 1 && state <= 4) {
      state++;
      continue;
    }
    if (state == 5) {
      if (line.find(end) != std::string::npos)
        break;
      std::vector<std::string> strs;
      boost::split(strs, line, boost::is_any_of(":"));
      // Skip malformed rows instead of indexing past the end of strs
      if (strs.size() < 4)
        continue;
      std::string variable = strs[2];
      boost::trim(variable);
      variable = Belle2::MakeROOTCompatible::invertMakeROOTCompatible(variable);
      float importance = std::stof(strs[3]);
      feature_importances[variable] = importance;
    }
  }
  weightfile.addFeatureImportance(feature_importances);

  return weightfile;

}
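For `m_type == "Plugins"`, `trainFactory()` registers two ROOT plugin handlers under the base `TMVA@@MethodBase`. Both name the class `TMVA::Method<m_method>` in the plugin `TMVA<m_method>`; they differ in the pattern and in the constructor signature. The sketch below only reproduces the string construction so the resulting values are easy to inspect for a given method name. `PluginHandlerStrings` and `makePluginHandlerStrings` are hypothetical, and `FastBDT` is used purely as an example value for `m_method`.

```cpp
#include <cassert>
#include <string>

// Hypothetical holder for the arguments trainFactory passes to AddHandler.
struct PluginHandlerStrings {
  std::string base;
  std::string regexp1;
  std::string regexp2;
  std::string className;
  std::string pluginName;
  std::string ctor1;
  std::string ctor2;
};

// Same concatenations as in trainFactory, for a given m_method value.
PluginHandlerStrings makePluginHandlerStrings(const std::string& method)
{
  PluginHandlerStrings s;
  s.base = "TMVA@@MethodBase";
  s.regexp1 = ".*_" + method + ".*";
  s.regexp2 = ".*" + method + ".*";
  s.className = "TMVA::Method" + method;
  s.pluginName = "TMVA" + method;
  s.ctor1 = "Method" + method + "(TMVA::DataSetInfo&,TString)";
  s.ctor2 = "Method" + method + "(TString&,TString&,TMVA::DataSetInfo&,TString&)";
  return s;
}
```

In `trainFactory()` these strings are passed, in the order base, pattern, class name, plugin name, constructor, to `gROOT->GetPluginManager()->AddHandler`, once per pattern/constructor pair.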
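`trainFactory()` captures TMVA's console output by pointing file descriptor 1 at a log file with `dup`/`dup2` while `TrainAllMethods`, `TestAllMethods` and `EvaluateAllMethods` run, then restoring the saved descriptor. A minimal POSIX sketch of that technique follows; `captureStdout` and `readFile` are hypothetical helpers. It checks the result of `open` and flushes both `std::cout` and the C `stdout` buffer before each `dup2`, so that output buffered before or during the redirection ends up on the intended side.

```cpp
#include <cassert>
#include <cstdio>
#include <fcntl.h>
#include <fstream>
#include <functional>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unistd.h>

// Run `work` with file descriptor 1 (stdout) redirected to `logfilename`,
// then restore the original descriptor. Not exception-safe: if `work`
// throws, stdout stays redirected.
void captureStdout(const std::string& logfilename, const std::function<void()>& work)
{
  const int logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
  if (logfile < 0)
    throw std::runtime_error("cannot open " + logfilename);

  std::cout.flush();  // earlier output must not end up in the log
  std::fflush(stdout);
  const int saved_stdout = dup(STDOUT_FILENO);
  dup2(logfile, STDOUT_FILENO);

  work();  // e.g. TrainAllMethods(), TestAllMethods(), EvaluateAllMethods()

  std::cout.flush();  // captured output must reach the log before switching back
  std::fflush(stdout);
  dup2(saved_stdout, STDOUT_FILENO);
  close(saved_stdout);
  close(logfile);
}

// Read a whole text file into a string.
std::string readFile(const std::string& filename)
{
  std::ifstream in(filename);
  std::ostringstream content;
  content << in.rdbuf();
  return content.str();
}
```

Redirecting at the descriptor level catches both C++ stream output and C `printf` output, which is why this approach works for libraries that do not offer a configurable logger.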
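As the comment in the `trainFactory()` listing notes, the feature importances are only available from the TMVA output, so the log is scanned with a small state machine: after the line containing `Ranking input variables (method specific)` four lines are skipped, then each row is split on `:` and the variable name (field 2) and importance (field 3) are read until a dashed rule appears. The standard-library-only sketch below follows the same steps; `splitOnColon`, `trim` and `parseRanking` are hypothetical helpers standing in for `boost::split` and `boost::trim`, and mapping the names back with `invertMakeROOTCompatible` is left out.

```cpp
#include <cassert>
#include <istream>
#include <map>
#include <sstream>  // std::istringstream, handy for feeding log excerpts
#include <string>
#include <vector>

// Split on every ':' (what boost::split with boost::is_any_of(":") does).
std::vector<std::string> splitOnColon(const std::string& line)
{
  std::vector<std::string> parts;
  std::string::size_type start = 0;
  while (true) {
    const auto pos = line.find(':', start);
    if (pos == std::string::npos) {
      parts.push_back(line.substr(start));
      return parts;
    }
    parts.push_back(line.substr(start, pos - start));
    start = pos + 1;
  }
}

// Strip leading and trailing whitespace (what boost::trim does).
std::string trim(const std::string& s)
{
  const char* whitespace = " \t\r\n";
  const auto first = s.find_first_not_of(whitespace);
  if (first == std::string::npos)
    return "";
  const auto last = s.find_last_not_of(whitespace);
  return s.substr(first, last - first + 1);
}

// State machine: find the header, skip the next four lines, then read
// "<prefix> : <rank> : <variable> : <importance>" rows until a dashed rule.
std::map<std::string, float> parseRanking(std::istream& logStream)
{
  const std::string begin = "Ranking input variables (method specific)";
  const std::string end = "-----------------------------------";
  std::map<std::string, float> importances;
  std::string line;
  int state = 0;
  while (std::getline(logStream, line)) {
    if (state == 0) {
      if (line.find(begin) != std::string::npos)
        state = 1;
      continue;
    }
    if (state < 5) {
      ++state;
      continue;
    }
    if (line.find(end) != std::string::npos)
      break;
    const auto fields = splitOnColon(line);
    if (fields.size() < 4)
      continue;  // skip rows that do not have the expected four fields
    importances[trim(fields[2])] = std::stof(fields[3]);  // throws on a non-numeric field
  }
  return importances;
}
```

`parseRanking` accepts any `std::istream`, for example a `std::ifstream` on the log file or a `std::istringstream` holding a test excerpt. The row layout it expects is inferred from the parsing code, not taken from a real TMVA run; the exact layout of TMVA's ranking table is the main fragility of this approach.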

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
protected, inherited

GeneralOptions containing all shared options.

Definition at line 49 of file Teacher.h.

◆ specific_options

TMVAOptionsClassification specific_options
protected

Method-specific options.

Definition at line 231 of file TMVA.h.


The documentation for this class was generated from the following files: