Belle II Software development
CDCTrigger3DHNeuroDataModule.h
/**************************************************************************
 * basf2 (Belle II Analysis Software Framework)                           *
 * Author: The Belle II Collaboration                                     *
 *                                                                        *
 * See git log for contributors and copyright holders.                    *
 * This file is licensed under LGPL-3.0, see LICENSE.md.                  *
 **************************************************************************/
8
#pragma once

#include <cstddef>
#include <string>
#include <vector>

#include <boost/iostreams/filter/gzip.hpp>
#include <boost/iostreams/filtering_streambuf.hpp>
#include <boost/iostreams/filtering_stream.hpp>
#include <boost/multi_array.hpp>

#include "trg/cdc/NeuroTrigger3DH.h"
#include "trg/cdc/dataobjects/CDCTriggerHoughMLP.h"
#include "trg/cdc/dataobjects/CDCTrigger3DHTrack.h"
#include "tracking/dataobjects/RecoTrack.h"
#include "framework/core/Module.h"
#include "framework/datastore/StoreArray.h"

// NOTE(review): this define comes after <boost/multi_array.hpp> has already been
// included above, so it has no effect on that include (the header's include guard
// prevents it from being processed again). Moving it before the include would
// suppress the boost::extents / boost::indices generator objects; only do that after
// confirming that no code including this header relies on them.
#define BOOST_MULTI_ARRAY_NO_GENERATORS
26
27namespace Belle2 {
32 // Class for the data generation for the network training data set using 3DHough Finder input tracks
33 class CDCTrigger3DHNeuroDataModule : public Module {
34 public:
35 // Struct for target results
36 struct TargetResult {
37 std::vector<float> targetVector;
38 unsigned short trackType;
39 };
40 // Enum for the (target) track type
41 enum class TrackType : unsigned short {
42 Real = 0,
43 Background = 1,
44 Fake = 2,
45 UnrelatedFake = 3,
46 Unknown = 4
47 };
48
49 // Construtor
51 // Destructor
53
54 virtual void initialize() override;
55 virtual void event() override;
56
57
58 private:
59 // Write the headline to the .gz file
60 void writeHeadline() const;
61 // Compute scaled target vector from reco track
62 TargetResult computeTargetVector(const CDCTrigger3DHTrack& ndFinderTrack, const bool isFakeEvent) const;
63 // Get the (target) track type
64 TrackType determineTrackType(const float classificationNNT, const bool isFakeEvent, const bool isUnrelatedFake) const;
65
66 // Data generation parameters
67 // Name of the StoreArray containing the input track segment hits
68 std::string m_hitCollectionName;
69 // Name of the StoreArray containing the input 3D tracks
70 std::string m_inputCollectionName;
71 // Name of the StoreArray containing the reconstructed tracks used as target values
72 std::string m_targetCollectionName;
73 // Name of the configuration file used in the module to load the network configuration
74 std::string m_configFileName;
75 // Name of gzip file where the training data is saved
76 std::string m_filename;
77 // Flag to save the 3DFinder tracks from fake events (no reconstructed track present)
78 bool m_saveFakeEventTracks;
79 // Flag to save the 3DFinder tracks that have no relation to a reconstructed track
80 bool m_saveFakeUnrelatedTracks;
81
82 // Parameters for the 3DHough input NeuroTrigger
83 NeuroParametersHough m_neuroParameters3DH;
84 // Instance of the 3DHough input NeuroTrigger
85 NeuroTrigger3DH m_neuroTrigger3DH;
86 // StoreArray of input tracks
87 StoreArray<CDCTrigger3DHTrack> m_ndFinderTracks;
88 // StoreArray of reco tracks
89 StoreArray<RecoTrack> m_recoTracks;
90
91 // Number of super layers
92 static constexpr size_t m_nSL = 9;
93 };
94
95}
virtual void initialize() override
Initialize the Module.
virtual void event() override
This method is the core of the module.
Module()
Constructor.
Definition Module.cc:30
Abstract base class for different kinds of events.