Belle II Software release-08-01-10
FastBDTClassifierTrainingModule.h
/**************************************************************************
 * basf2 (Belle II Analysis Software Framework) *
 * Author: The Belle II Collaboration *
 * *
 * See git log for contributors and copyright holders. *
 * This file is licensed under LGPL-3.0, see LICENSE.md. *
 **************************************************************************/

#pragma once

#include <tracking/trackFindingVXD/filterTools/FBDTClassifierHelper.h> // for the FBDTTrainSample
#include <tracking/trackFindingVXD/segmentNetwork/DirectedNodeNetworkContainer.h>

#include <tracking/spacePointCreation/SpacePoint.h>

#include <framework/core/Module.h>
#include <framework/datastore/StoreObjPtr.h>

#include <string>
#include <vector>

namespace Belle2 {

  // ...
  class FastBDTClassifierTrainingModule : public Module {

  public:

    // ...

    void initialize() override;

    void event() override;

    void terminate() override;

  private:

    // ...

    std::vector<TrainSample> m_samples;

    // ...

    const TrainSample makeTrainSample(const Belle2::SpacePoint* outerHit, const Belle2::SpacePoint* centerHit,
                                      const Belle2::SpacePoint* innerHit);
  };

} // end namespace Belle2
FastBDTClassifierTrainingModule
Module for collecting the data and training a FastBDT classifier.
double m_PARAMrandRatio
ratio of samples to be used for training one tree in the FastBDT.
std::string m_PARAMsamplesFileName
filename used to store or read the collected samples
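The format of the samples file is not shown on this page. As a minimal sketch, assuming a plain-text format with one sample per line (target flag first, then the input values), storing and re-reading samples could look like this. `Sample`, `writeSamples` and `readSamples` are hypothetical names, not basf2 API.

```cpp
#include <array>
#include <iomanip>
#include <istream>
#include <limits>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for a training sample: nine inputs plus a target flag.
struct Sample {
  std::array<double, 9> inputs{};
  bool signal = false;
};

// Write one sample per line: the target flag (1/0) first, then the inputs.
void writeSamples(std::ostream& os, const std::vector<Sample>& samples)
{
  // max_digits10 digits are enough for every double to round-trip exactly.
  os << std::setprecision(std::numeric_limits<double>::max_digits10);
  for (const Sample& s : samples) {
    os << s.signal;
    for (double v : s.inputs) os << ' ' << v;
    os << '\n';
  }
}

// Read samples written by writeSamples; stops at the first malformed line.
std::vector<Sample> readSamples(std::istream& is)
{
  std::vector<Sample> samples;
  std::string line;
  while (std::getline(is, line)) {
    std::istringstream ls(line);
    Sample s;
    if (!(ls >> s.signal)) break;
    bool ok = true;
    for (double& v : s.inputs) {
      if (!(ls >> v)) { ok = false; break; }
    }
    if (!ok) break;
    samples.push_back(s);
  }
  return samples;
}
```

Writing with `max_digits10` precision means a stored-then-reloaded sample set is bit-identical to the collected one, so training on reloaded samples gives the same inputs as training directly.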
void initialize() override
initialize the module
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs (directed node networks) that are used in this module.
void event() override
collect all possible three-hit combinations from the network and store them as training samples
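Judging from the network StoreObjPtr and the three-hit signature of makeTrainSample, the collection step forms outer-center-inner hit triplets from the directed node network. A minimal sketch of that enumeration, assuming a plain adjacency list in place of the real DirectedNodeNetworkContainer (`HitNetwork`, `Triplet` and `collectTriplets` are hypothetical):

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Hypothetical directed network: node i points to the nodes listed in inner[i],
// i.e. hits lying further inward along a possible track.
// Assumes every listed neighbour index is a valid node index.
struct HitNetwork {
  std::vector<std::vector<std::size_t>> inner;
};

// One three-hit combination as node indices (outer, center, inner).
using Triplet = std::array<std::size_t, 3>;

// Collect every directed path outer -> center -> inner of length two.
std::vector<Triplet> collectTriplets(const HitNetwork& net)
{
  std::vector<Triplet> triplets;
  for (std::size_t outer = 0; outer < net.inner.size(); ++outer) {
    for (std::size_t center : net.inner[outer]) {
      for (std::size_t innerHit : net.inner[center]) {
        triplets.push_back({outer, center, innerHit});
      }
    }
  }
  return triplets;
}
```

For a network with edges 0→1, 0→2, 1→3 and 2→3 this yields the triplets (0, 1, 3) and (0, 2, 3); each one would then be turned into a training sample.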
std::vector< TrainSample > m_samples
vector in which all samples are collected on the fly in event().
const TrainSample makeTrainSample(const Belle2::SpacePoint *outerHit, const Belle2::SpacePoint *centerHit, const Belle2::SpacePoint *innerHit)
create a training sample from the three-hit combination
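The inputs and target that makeTrainSample actually uses are not shown here. The sketch below only illustrates the idea of bundling a triplet into one sample, assuming the inputs are the nine global coordinates of the three hits (outer, center, inner) and the target is whether all three hits stem from the same MC particle. Both choices are assumptions, and `ToyHit`, `ToySample` and `makeToySample` are hypothetical.

```cpp
#include <array>

// Hypothetical hit: global position plus the id of the MC particle that created it.
struct ToyHit {
  double x, y, z;
  int mcParticleId;
};

// Hypothetical sample: nine coordinates (outer, center, inner) plus the target.
struct ToySample {
  std::array<double, 9> inputs{};
  bool signal = false;
};

// Pack a three-hit combination into one sample.
ToySample makeToySample(const ToyHit& outer, const ToyHit& center, const ToyHit& inner)
{
  ToySample sample;
  const ToyHit* hits[3] = {&outer, &center, &inner};
  for (int i = 0; i < 3; ++i) {
    sample.inputs[3 * i + 0] = hits[i]->x;
    sample.inputs[3 * i + 1] = hits[i]->y;
    sample.inputs[3 * i + 2] = hits[i]->z;
  }
  // Illustrative truth criterion: all hits come from the same MC particle.
  sample.signal = outer.mcParticleId == center.mcParticleId &&
                  center.mcParticleId == inner.mcParticleId;
  return sample;
}
```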
void terminate() override
take the collected data, train an FBDTClassifier and store it in the given output file
bool m_PARAMdoTrain
actually train a classifier or only do collection
std::string m_PARAMfbdtOutFileName
output file name into which the FBDTClassifier is stored.
double m_PARAMshrinkage
shrinkage parameter of FastBDT.
bool m_PARAMstoreSamples
store the collected samples into a file
std::string m_PARAMnetworkInputName
name of the StoreObjPtr in which the network container is stored which contains the network that is u...
bool m_PARAMuseSamples
use pre-collected samples for training and bypass the collection step
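The boolean parameters doTrain, storeSamples and useSamples decide what event() and terminate() do. The following sketch encodes one plausible reading of their descriptions; in particular it assumes that storeSamples only applies to freshly collected samples. `Flags`, `collectsInEvent` and `terminateSteps` are hypothetical helpers, and the `false` defaults are arbitrary (the module's real defaults are not shown).

```cpp
#include <string>
#include <vector>

// Hypothetical mirror of the module's boolean steering parameters.
struct Flags {
  bool useSamples = false;   // read pre-collected samples instead of collecting
  bool storeSamples = false; // write the collected samples to the samples file
  bool doTrain = false;      // train a classifier and write it to the output file
};

// event(): collection is bypassed when pre-collected samples are used.
bool collectsInEvent(const Flags& f) { return !f.useSamples; }

// terminate(): ordered steps implied by the flags (assumed semantics).
std::vector<std::string> terminateSteps(const Flags& f)
{
  std::vector<std::string> steps;
  if (f.useSamples) {
    steps.push_back("read samples from samples file");
  } else if (f.storeSamples) {
    steps.push_back("write collected samples to samples file");
  }
  if (f.doTrain) {
    steps.push_back("train classifier");
    steps.push_back("write classifier to output file");
  }
  return steps;
}
```

Under this reading, a collect-only pass (storeSamples on, doTrain off) followed by a training pass (useSamples and doTrain on) separates the expensive collection from repeated training with different hyperparameters.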
int m_PARAMnTrees
number of trees in the FastBDT.
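nTrees, shrinkage and randRatio are stochastic gradient boosting settings: the number of boosting rounds, the factor that scales each tree's contribution, and the fraction of samples drawn to train each tree. To show how they interact, here is a generic boosting loop with one-dimensional regression stumps and squared loss. It is a toy illustration of the technique, not FastBDT's implementation, and `Stump`, `Boosted`, `fitStump` and `train` are hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Regression stump on one feature: x < cut ? left : right.
struct Stump {
  double cut, left, right;
  double operator()(double x) const { return x < cut ? left : right; }
};

// Fit a stump to residuals r on the subsample idx, minimising squared error.
Stump fitStump(const std::vector<double>& x, const std::vector<double>& r,
               const std::vector<std::size_t>& idx)
{
  Stump best{0.0, 0.0, 0.0};
  double bestErr = -1.0;
  for (std::size_t c : idx) {  // try every subsample value as the cut
    const double cut = x[c];
    double sumL = 0.0, sumR = 0.0;
    std::size_t nL = 0, nR = 0;
    for (std::size_t i : idx) {
      if (x[i] < cut) { sumL += r[i]; ++nL; } else { sumR += r[i]; ++nR; }
    }
    const double meanL = nL ? sumL / nL : 0.0;
    const double meanR = nR ? sumR / nR : 0.0;
    double err = 0.0;
    for (std::size_t i : idx) {
      const double d = r[i] - (x[i] < cut ? meanL : meanR);
      err += d * d;
    }
    if (bestErr < 0.0 || err < bestErr) {
      bestErr = err;
      best = Stump{cut, meanL, meanR};
    }
  }
  return best;
}

// Boosted model: F(x) = base + shrinkage * sum of all tree outputs.
struct Boosted {
  double base = 0.0, shrinkage = 1.0;
  std::vector<Stump> trees;
  double operator()(double x) const {
    double f = base;
    for (const Stump& t : trees) f += shrinkage * t(x);
    return f;
  }
};

// nTrees rounds; each tree is fit on a random fraction randRatio of the samples.
// Assumes x and y are non-empty and of equal size.
Boosted train(const std::vector<double>& x, const std::vector<double>& y,
              int nTrees, double shrinkage, double randRatio, unsigned seed = 1)
{
  Boosted model;
  model.shrinkage = shrinkage;
  model.base = std::accumulate(y.begin(), y.end(), 0.0) / static_cast<double>(y.size());
  std::vector<std::size_t> order(x.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::size_t nSub = static_cast<std::size_t>(randRatio * static_cast<double>(x.size()));
  nSub = std::min(x.size(), std::max<std::size_t>(1, nSub));
  std::mt19937 rng(seed);
  std::vector<double> residual(x.size());
  for (int m = 0; m < nTrees; ++m) {
    // Negative gradient of 0.5 * (y - F)^2 with respect to F.
    for (std::size_t i = 0; i < x.size(); ++i) residual[i] = y[i] - model(x[i]);
    std::shuffle(order.begin(), order.end(), rng);
    const std::vector<std::size_t> subsample(order.begin(),
                                             order.begin() + static_cast<std::ptrdiff_t>(nSub));
    model.trees.push_back(fitStump(x, residual, subsample));
  }
  return model;
}
```

With randRatio = 1 every tree sees all samples; smaller values make each tree cheaper and more randomised, while a smaller shrinkage usually needs more trees to reach the same fit.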
Module
Base class for Modules.
Definition: Module.h:72
SpacePoint
A SpacePoint is typically built from 1 PXDCluster or 1-2 SVDClusters.
Definition: SpacePoint.h:42
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.