Belle II Software development
FastBDTClassifierTrainingModule.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10
11#include <tracking/trackFindingVXD/filterTools/FBDTClassifierHelper.h> // for the FBDTTrainSample
12#include <tracking/trackFindingVXD/segmentNetwork/DirectedNodeNetworkContainer.h>
13
14#include <tracking/spacePointCreation/SpacePoint.h>
15
16#include <framework/core/Module.h>
17#include <framework/datastore/StoreObjPtr.h>
18
19#include <string>
20#include <vector>
21
22namespace Belle2 {
27
28
36
37 public:
38
41
/** Initialize the module. NOTE(review): the definition is not visible in this header; presumably it registers/checks m_network (named by m_PARAMnetworkInputName) — TODO confirm against the .cc file. */
43 void initialize() override;
44
/** Collect all possible combinations in the current event and store them as training samples in m_samples. */
46 void event() override;
47
/** Take the collected data, train an FBDTClassifier and store it in the output file given by m_PARAMfbdtOutFileName. NOTE(review): presumably training only happens when m_PARAMdoTrain is set — confirm in the .cc file. */
49 void terminate() override;
50
51 private:
52
55
58
61
64
67
70
73
76
79
82
85
/** Vector in which all samples are collected on the fly in event(). */
87 std::vector<TrainSample> m_samples;
88
91
/**
 * Create a training sample from the three-hit combination (outerHit, centerHit, innerHit).
 * The returned TrainSample (Belle2::FBDTTrainSample<9>) bundles the classifier input and the target value.
 */
93 const TrainSample makeTrainSample(const Belle2::SpacePoint* outerHit, const Belle2::SpacePoint* centerHit,
94 const Belle2::SpacePoint* innerHit);
95 };
96
98} // end namespace Belle2
double m_PARAMrandRatio
ratio of samples to be used for training one tree in the FastBDT.
std::string m_PARAMsamplesFileName
filename used to store / read the collected samples
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
void event() override
collect all possible combinations and store them
std::vector< TrainSample > m_samples
vector in which all samples are collected on the fly in event.
const TrainSample makeTrainSample(const Belle2::SpacePoint *outerHit, const Belle2::SpacePoint *centerHit, const Belle2::SpacePoint *innerHit)
create a training sample from the three-hit combination
void terminate() override
take the collected data, train an FBDTClassifier, and store it in the given output file
Belle2::FBDTTrainSample< 9 > TrainSample
private typedef for shorter notation
bool m_PARAMdoTrain
whether to actually train a classifier or only collect samples
std::string m_PARAMfbdtOutFileName
output file name into which the FBDTClassifier is stored.
double m_PARAMshrinkage
shrinkage parameter of FastBDT.
bool m_PARAMstoreSamples
store the collected samples into a file
std::string m_PARAMnetworkInputName
name of the StoreObjPtr in which the network container is stored (the container holding the network(s) used by this module)
bool m_PARAMuseSamples
use pre-collected samples for training and bypass the collection step
Module()
Constructor.
Definition Module.cc:30
A SpacePoint is typically built from 1 PXDCluster or 1–2 SVDClusters.
Definition SpacePoint.h:42
Type-safe access to single objects in the data store.
Definition StoreObjPtr.h:96
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.