Belle II Software development
NeuroTrainer.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8#pragma once
9#include <iostream>
10#include <sstream>
11
12#include <string>
13#include <trg/cdc/dataobjects/CDCTriggerMLP.h>
14#include <framework/datastore/StoreArray.h>
15#include <framework/datastore/StoreObjPtr.h>
16#include <trg/cdc/dataobjects/CDCTriggerSegmentHit.h>
17#include <mdst/dataobjects/MCParticle.h>
18#include <tracking/dataobjects/RecoTrack.h>
19#include <trg/cdc/dataobjects/CDCTriggerMLPData.h>
20#include <trg/cdc/dataobjects/CDCTriggerTrack.h>
21#include <trg/cdc/dataobjects/CDCTriggerMLPData.h>
22#include <framework/dataobjects/EventMetaData.h>
23#include <cdc/geometry/CDCGeometryPar.h>
24#include <framework/gearbox/Unit.h>
25#include <cmath>
26#include <TFile.h>
27#include <framework/geometry/B2Vector3.h>
28
29
30using namespace Belle2;
31namespace NeuroTrainer {
32 std::vector<float> getTrainTargets(bool& trainonreco, CDCTriggerTrack* twodtrack, std::string targetcollectionname)
33 {
34 std::vector<float> ret;
35 float phi0Target = 0;
36 float invptTarget = 0;
37 float thetaTarget = 0;
38 float zTarget = 0;
39 float isvalid = 1;
40 if (trainonreco) {
41 RecoTrack* recoTrack =
42 twodtrack->getRelatedTo<RecoTrack>(targetcollectionname);
43 if (!recoTrack) {
44 B2DEBUG(29, "Skipping CDCTriggerTrack without relation to RecoTrack.");
45 isvalid = 0;
46 } else {
47 // a RecoTrack has multiple representations for different particle hypothesis
48 // -> just take the first one that does not give errors.
49 const std::vector<genfit::AbsTrackRep*>& reps = recoTrack->getRepresentations();
50 bool foundValidRep = false;
51 for (unsigned irep = 0; irep < reps.size() && !foundValidRep; ++irep) {
52 if (!recoTrack->wasFitSuccessful(reps[irep]))
53 continue;
54 // get state (position, momentum etc.) from hit closest to IP and
55 // extrapolate to z-axis (may throw an exception -> continue to next representation)
56 try {
57 genfit::MeasuredStateOnPlane state =
58 recoTrack->getMeasuredStateOnPlaneClosestTo(ROOT::Math::XYZVector(0, 0, 0), reps[irep]);
59 reps[irep]->extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000));
60 // flip tracks if necessary, such that trigger tracks and reco tracks
61 // point in the same direction
62 if (state.getMom().Dot(B2Vector3D(twodtrack->getDirection())) < 0) {
63 state.setPosMom(state.getPos(), -state.getMom());
64 state.setChargeSign(-state.getCharge());
65 }
66 // get track parameters
67 phi0Target = state.getMom().Phi();
68 invptTarget = state.getCharge() / state.getMom().Pt();
69 thetaTarget = state.getMom().Theta();
70 zTarget = state.getPos().Z();
71 } catch (...) {
72 continue;
73 }
74 // break loop
75 foundValidRep = true;
76 }
77 }
78 } else {
79 MCParticle* mcTrack =
80 twodtrack->getRelatedTo<MCParticle>(targetcollectionname);
81 if (not mcTrack) {
82 B2DEBUG(29, "Skipping CDCTriggerTrack without relation to MCParticle.");
83 isvalid = 0;
84 } else {
85 phi0Target = mcTrack->getMomentum().Phi();
86 invptTarget = mcTrack->getCharge() / mcTrack->getMomentum().Rho();
87 thetaTarget = mcTrack->getMomentum().Theta();
88 zTarget = mcTrack->getProductionVertex().Z();
89 }
90 }
91 ret.push_back(phi0Target);
92 ret.push_back(invptTarget);
93 ret.push_back(thetaTarget);
94 ret.push_back(zTarget);
95 ret.push_back(isvalid);
96 return ret;
97 }
98 std::vector<float>
99 getRelevantID(CDCTriggerMLPData& trainSet_prepare, double cutsum, double relevantcut)
100 {
101 std::vector<float> relevantID;
102 relevantID.assign(18, 0.);
104 int layerId = 3;
105 for (unsigned iSL = 0; iSL < 9; ++iSL) {
106 int nWires = cdc.nWiresInLayer(layerId);
107 layerId += (iSL > 0 ? 6 : 7);
108 B2DEBUG(28, "SL " << iSL << " (" << nWires << " wires)");
109 // get maximum hit counter
110 unsigned maxCounter = 0;
111 int maxId = 0;
112 unsigned counterSum = 0;
113 for (int iTS = 0; iTS < nWires; ++iTS) {
114 if (trainSet_prepare.getHitCounter(iSL, iTS) > 0)
115 B2DEBUG(28, iTS << " " << trainSet_prepare.getHitCounter(iSL, iTS));
116 if (trainSet_prepare.getHitCounter(iSL, iTS) > maxCounter) {
117 maxCounter = trainSet_prepare.getHitCounter(iSL, iTS);
118 maxId = iTS;
119 }
120 counterSum += trainSet_prepare.getHitCounter(iSL, iTS);
121 }
122 // use maximum as starting range
123 if (maxId > nWires / 2) maxId -= nWires;
124 relevantID[2 * iSL] = maxId;
125 relevantID[2 * iSL + 1] = maxId;
126 if (cutsum) {
127 // add neighboring wire with higher hit count
128 // until sum over unused wires is less than relevantcut * sum over all wires
129 double cut = relevantcut * counterSum;
130 B2DEBUG(28, "Threshold on counterSum: " << cut);
131 unsigned relevantSum = maxCounter;
132 while (counterSum - relevantSum > cut) {
133 int prev = trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL] - 1);
134 int next = trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL + 1] + 1);
135 if (prev > next ||
136 (prev == next &&
137 (relevantID[2 * iSL + 1] - maxId) > (maxId - relevantID[2 * iSL]))) {
138 --relevantID[2 * iSL];
139 relevantSum += prev;
140 if (relevantID[2 * iSL] <= -nWires) break;
141 } else {
142 ++relevantID[2 * iSL + 1];
143 relevantSum += next;
144 if (relevantID[2 * iSL + 1] >= nWires - 1) break;
145 }
146 }
147 } else {
148 // add wires from both sides until hit counter drops below relevantcut * track counter
149 double cut = relevantcut * trainSet_prepare.getTrackCounter();
150 B2DEBUG(28, "Threshold on counter: " << cut);
151 while (trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL] - 1) > cut) {
152 --relevantID[2 * iSL];
153 if (relevantID[2 * iSL] <= -nWires) break;
154 }
155 while (trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL + 1] + 1) > cut) {
156 ++relevantID[2 * iSL + 1];
157 if (relevantID[2 * iSL + 1] >= nWires - 1) break;
158 }
159 }
160 // add +-0.5 to account for rounding during preparation
161 relevantID[2 * iSL] -= 0.5;
162 relevantID[2 * iSL + 1] += 0.5;
163 B2DEBUG(28, "SL " << iSL << ": "
164 << relevantID[2 * iSL] << " " << relevantID[2 * iSL + 1]);
165 }
166 return relevantID;
167 }
168
169
170}
Struct for training data of a single MLP for the neuro trigger.
short getTrackCounter() const
get track counter
unsigned short getHitCounter(unsigned iSL, int iTS) const
Get the hit counter for the given super layer and track segment number in that super layer.
Track created by the CDC trigger.
The Class for CDC Geometry Parameters.
static CDCGeometryPar & Instance(const CDCGeometry *=nullptr)
Static method to get a reference to the CDCGeometryPar instance.
A Class to store the Monte Carlo particle information.
Definition: MCParticle.h:32
This is the Reconstruction Event-Data Model Track.
Definition: RecoTrack.h:79
const std::vector< genfit::AbsTrackRep * > & getRepresentations() const
Return a list of track representations. You are not allowed to modify or delete them!
Definition: RecoTrack.h:638
bool wasFitSuccessful(const genfit::AbsTrackRep *representation=nullptr) const
Returns true if the last fit with the given representation was successful.
Definition: RecoTrack.cc:336
const genfit::MeasuredStateOnPlane & getMeasuredStateOnPlaneClosestTo(const ROOT::Math::XYZVector &closestPoint, const genfit::AbsTrackRep *representation=nullptr)
Return genfit's MeasuredStateOnPlane that is closest to the given point; useful for extrapolation of m...
Definition: RecoTrack.cc:426
Abstract base class for different kinds of events.