208 {
209
210 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
211 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
212 unsigned int numberOfEvents = training_data.getNumberOfEvents();
213
216 char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2TMVA.XXXXXX").c_str());
217 directory = mkdtemp(directory_template);
218 free(directory_template);
219 }
220
221
223
225 if (jobName.empty())
226 jobName = "TMVA";
227 TFile classFile((jobName + ".root").c_str(), "RECREATE");
228 classFile.cd();
229
230 TMVA::Tools::Instance();
231 TMVA::DataLoader data_loader(jobName);
233
234
235
238 }
239
240
243 }
244
246
247 auto* signal_tree = new TTree("signal_tree", "signal_tree");
248 auto* background_tree = new TTree("background_tree", "background_tree");
249
250 for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
252 &training_data.m_input[iFeature]);
254 &training_data.m_input[iFeature]);
255 }
256
257 for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
259 &training_data.m_spectators[iSpectator]);
261 &training_data.m_spectators[iSpectator]);
262 }
263
264 signal_tree->Branch("__weight__", &training_data.m_weight);
265 background_tree->Branch("__weight__", &training_data.m_weight);
266
267 for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
268 training_data.loadEvent(iEvent);
269 if (training_data.m_isSignal) {
270 signal_tree->Fill();
271 } else {
272 background_tree->Fill();
273 }
274 }
275
276 data_loader.AddSignalTree(signal_tree);
277 data_loader.AddBackgroundTree(background_tree);
278 auto weightfile =
trainFactory(factory, data_loader, jobName);
279
281 weightfile.addSignalFraction(training_data.getSignalFraction());
282
283 delete signal_tree;
284 delete background_tree;
285
287 std::filesystem::remove_all(directory);
288 }
289
290 return weightfile;
291
292 }
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_prefix
Prefix used for all files generated by TMVA.
std::string m_factoryOption
Factory options passed to the TMVA factory.
std::string m_workingDirectory
Working directory of TMVA; if empty, a temporary directory is used.
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train an MVA method using the given data loader, returning a Weightfile.
GeneralOptions m_general_options
GeneralOptions containing all shared options.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names.
static ScopeGuard guardWorkingDirectory()
Create a ScopeGuard of the current working directory.