Belle II Software development
FastBDTClassifierAnalyzerModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <fstream>
10#include <TFile.h>
11#include <TTree.h>
12
13#include <tracking/modules/vxdtfRedesign/FastBDTClassifierAnalyzerModule.h>
14#include <tracking/spacePointCreation/MapHelperFunctions.h>
15
// Make the names of the Belle2 namespace visible without qualification in this file.
16using namespace Belle2;
17
// Register the module with the framework; the name is passed without the 'Module' suffix.
18REG_MODULE(FastBDTClassifierAnalyzer);
19
21{
// NOTE(review): the function header and the call that consumes the string literal
// below are not visible in this extract (the page numbering skips lines 20 and 22).
// Presumably this is the module constructor passing the text to setDescription() — confirm
// against the original file.
23 "analyzes performance of given FastBDT on a test and a training set and determines a global classification cut. TODO");
24
// Steering parameters: classifier file, training-sample file, test-sample file and
// ROOT output file. Only the output file name is given a default value here.
25 addParam("fbdtFileName", m_PARAMfbdtFileName, "file name of the fbdtclassifier");
26 addParam("trainSamples", m_PARAMtrainSampleFileName, "filename of the training samples");
27 addParam("testSamples", m_PARAMtestSampleFileName, "filename of the test samples");
28 addParam("outputFileName", m_PARAMrootOutFileName, "output filename", std::string("FBDTAnalyzer_out.root"));
29}
30
32{
33 std::ifstream fbdt(m_PARAMfbdtFileName);
34 if (!fbdt.is_open()) {
35 B2ERROR("Could not open file: " << m_PARAMfbdtFileName << ".");
36 }
37
38 std::ifstream train(m_PARAMtrainSampleFileName);
39 if (!train.is_open()) {
40 B2ERROR("Could not open file: " << m_PARAMtrainSampleFileName << ".");
41 }
42
43 std::ifstream test(m_PARAMtestSampleFileName);
44 if (!test.is_open()) {
45 B2ERROR("Could not open file: " << m_PARAMtestSampleFileName << ".");
46 }
47
48 B2DEBUG(20, "Reading Classifier from file: " << m_PARAMfbdtFileName << ".");
50 fbdt.close();
51 B2DEBUG(20, "Done");
52
53 B2DEBUG(20, "Reading training samples from file: " << m_PARAMtrainSampleFileName << ".");
55 train.close();
56
57 B2DEBUG(20, "Reading training samples from file: " << m_PARAMtestSampleFileName << ".");
59 test.close();
60}
61
63{
64 std::ofstream ofs("analyze_trout.dat");
65 B2DEBUG(21, "Processing the training sample");
66 for (const auto& event : m_trainSample) {
67 m_trainOutput.insert(std::make_pair(event.signal, m_classifier.analyze(event.hits)));
68 ofs << event.signal << " " << m_classifier.analyze(event.hits) << std::endl;
69 }
70 ofs.close();
71
72 B2DEBUG(21, "Processing the test sample");
73 for (const auto& event : m_testSample) {
74 m_testOutput.insert(std::make_pair(event.signal, m_classifier.analyze(event.hits)));
75 }
76
77 auto trainBgOut = getValuesToKey(m_trainOutput, 0);
78 auto trainSigOut = getValuesToKey(m_trainOutput, 1);
79
80 auto testBgOut = getValuesToKey(m_testOutput, 0);
81 auto testSigOut = getValuesToKey(m_testOutput, 1);
82
83 TFile* outfile = new TFile(m_PARAMrootOutFileName.c_str(), "RECREATE");
84 TTree* tree = new TTree("classifierOutputs", "outputs of FBDTClassifier for the different samples");
85 tree->Branch("train_bg_outputs", &trainBgOut);
86 tree->Branch("train_sig_outputs", &trainSigOut);
87 tree->Branch("test_bg_outputs", &testBgOut);
88 tree->Branch("test_sig_outputs", &testSigOut);
89
90 tree->Fill();
91 outfile->cd();
92 outfile->Write();
93 outfile->Close();
94}
void readFromStream(std::istream &is)
Read all the necessary data from the stream and fill the Forest and the FeatureBinnings. NOTE: uses FastBD...
std::string m_PARAMtrainSampleFileName
training sample file name
std::string m_PARAMtestSampleFileName
test sample file name
std::multimap< int, double > m_trainOutput
map containing output for each training event
void initialize() override
Module initialization.
std::multimap< int, double > m_testOutput
map containing output for each test event
std::vector< TrainSample > m_trainSample
vector for training sample
std::vector< TrainSample > m_testSample
vector for test sample
Belle2::FBDTClassifier< 9 > m_classifier
classifier
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
virtual void event()
This method is the core of the module.
Definition: Module.h:157
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
double analyze(const std::array< double, Ndims > &hits) const
calculate the output of the FastBDT.
std::vector< typename MapType::mapped_type > getValuesToKey(const MapType &aMap, typename MapType::key_type aKey)
get all values stored in the map for a given key
static void readSamplesFromStream(std::istream &is, std::vector< FBDTTrainSample< Ndims > > &samples)
read samples from stream and append them to samples
Abstract base class for different kinds of events.