Belle II Software development
NeuroTrainer.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8#pragma once
9#include <iostream>
10#include <sstream>
11
12#include <string>
13#include <trg/cdc/dataobjects/CDCTriggerMLP.h>
14#include <framework/datastore/StoreArray.h>
15#include <framework/datastore/StoreObjPtr.h>
16#include <trg/cdc/dataobjects/CDCTriggerSegmentHit.h>
17#include <mdst/dataobjects/MCParticle.h>
18#include <tracking/dataobjects/RecoTrack.h>
19#include <trg/cdc/dataobjects/CDCTriggerMLPData.h>
20#include <trg/cdc/dataobjects/CDCTriggerTrack.h>
21#include <trg/cdc/dataobjects/CDCTriggerMLPData.h>
22#include <framework/dataobjects/EventMetaData.h>
23#include <cdc/geometry/CDCGeometryPar.h>
24#include <framework/gearbox/Unit.h>
25#include <cmath>
26#include <TFile.h>
27#include <framework/geometry/B2Vector3.h>
28
29
30namespace Belle2::NeuroTrainer {
31 std::vector<float> getTrainTargets(bool& trainonreco, CDCTriggerTrack* twodtrack, std::string targetcollectionname)
32 {
33 std::vector<float> ret;
34 float phi0Target = 0;
35 float invptTarget = 0;
36 float thetaTarget = 0;
37 float zTarget = 0;
38 float isvalid = 1;
39 if (trainonreco) {
40 RecoTrack* recoTrack =
41 twodtrack->getRelatedTo<RecoTrack>(targetcollectionname);
42 if (!recoTrack) {
43 B2DEBUG(29, "Skipping CDCTriggerTrack without relation to RecoTrack.");
44 isvalid = 0;
45 } else {
46 // a RecoTrack has multiple representations for different particle hypothesis
47 // -> just take the first one that does not give errors.
48 const std::vector<genfit::AbsTrackRep*>& reps = recoTrack->getRepresentations();
49 bool foundValidRep = false;
50 for (unsigned irep = 0; irep < reps.size() && !foundValidRep; ++irep) {
51 if (!recoTrack->wasFitSuccessful(reps[irep]))
52 continue;
53 // get state (position, momentum etc.) from hit closest to IP and
54 // extrapolate to z-axis (may throw an exception -> continue to next representation)
55 try {
56 genfit::MeasuredStateOnPlane state =
57 recoTrack->getMeasuredStateOnPlaneClosestTo(ROOT::Math::XYZVector(0, 0, 0), reps[irep]);
58 reps[irep]->extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000));
59 // flip tracks if necessary, such that trigger tracks and reco tracks
60 // point in the same direction
61 if (state.getMom().Dot(B2Vector3D(twodtrack->getDirection())) < 0) {
62 state.setPosMom(state.getPos(), -state.getMom());
63 state.setChargeSign(-state.getCharge());
64 }
65 // get track parameters
66 phi0Target = state.getMom().Phi();
67 invptTarget = state.getCharge() / state.getMom().Pt();
68 thetaTarget = state.getMom().Theta();
69 zTarget = state.getPos().Z();
70 } catch (...) {
71 continue;
72 }
73 // break loop
74 foundValidRep = true;
75 }
76 }
77 } else {
78 MCParticle* mcTrack =
79 twodtrack->getRelatedTo<MCParticle>(targetcollectionname);
80 if (not mcTrack) {
81 B2DEBUG(29, "Skipping CDCTriggerTrack without relation to MCParticle.");
82 isvalid = 0;
83 } else {
84 phi0Target = mcTrack->getMomentum().Phi();
85 invptTarget = mcTrack->getCharge() / mcTrack->getMomentum().Rho();
86 thetaTarget = mcTrack->getMomentum().Theta();
87 zTarget = mcTrack->getProductionVertex().Z();
88 }
89 }
90 ret.push_back(phi0Target);
91 ret.push_back(invptTarget);
92 ret.push_back(thetaTarget);
93 ret.push_back(zTarget);
94 ret.push_back(isvalid);
95 return ret;
96 }
97 std::vector<float>
98 getRelevantID(CDCTriggerMLPData& trainSet_prepare, double cutsum, double relevantcut)
99 {
100 std::vector<float> relevantID;
101 relevantID.assign(18, 0.);
102 CDC::CDCGeometryPar& cdc = CDC::CDCGeometryPar::Instance();
103 int layerId = 3;
104 for (unsigned iSL = 0; iSL < 9; ++iSL) {
105 int nWires = cdc.nWiresInLayer(layerId);
106 layerId += (iSL > 0 ? 6 : 7);
107 B2DEBUG(28, "SL " << iSL << " (" << nWires << " wires)");
108 // get maximum hit counter
109 unsigned maxCounter = 0;
110 int maxId = 0;
111 unsigned counterSum = 0;
112 for (int iTS = 0; iTS < nWires; ++iTS) {
113 if (trainSet_prepare.getHitCounter(iSL, iTS) > 0)
114 B2DEBUG(28, iTS << " " << trainSet_prepare.getHitCounter(iSL, iTS));
115 if (trainSet_prepare.getHitCounter(iSL, iTS) > maxCounter) {
116 maxCounter = trainSet_prepare.getHitCounter(iSL, iTS);
117 maxId = iTS;
118 }
119 counterSum += trainSet_prepare.getHitCounter(iSL, iTS);
120 }
121 // use maximum as starting range
122 if (maxId > nWires / 2) maxId -= nWires;
123 relevantID[2 * iSL] = maxId;
124 relevantID[2 * iSL + 1] = maxId;
125 if (cutsum) {
126 // add neighboring wire with higher hit count
127 // until sum over unused wires is less than relevantcut * sum over all wires
128 double cut = relevantcut * counterSum;
129 B2DEBUG(28, "Threshold on counterSum: " << cut);
130 unsigned relevantSum = maxCounter;
131 while (counterSum - relevantSum > cut) {
132 int prev = trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL] - 1);
133 int next = trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL + 1] + 1);
134 if (prev > next ||
135 (prev == next &&
136 (relevantID[2 * iSL + 1] - maxId) > (maxId - relevantID[2 * iSL]))) {
137 --relevantID[2 * iSL];
138 relevantSum += prev;
139 if (relevantID[2 * iSL] <= -nWires) break;
140 } else {
141 ++relevantID[2 * iSL + 1];
142 relevantSum += next;
143 if (relevantID[2 * iSL + 1] >= nWires - 1) break;
144 }
145 }
146 } else {
147 // add wires from both sides until hit counter drops below relevantcut * track counter
148 double cut = relevantcut * trainSet_prepare.getTrackCounter();
149 B2DEBUG(28, "Threshold on counter: " << cut);
150 while (trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL] - 1) > cut) {
151 --relevantID[2 * iSL];
152 if (relevantID[2 * iSL] <= -nWires) break;
153 }
154 while (trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL + 1] + 1) > cut) {
155 ++relevantID[2 * iSL + 1];
156 if (relevantID[2 * iSL + 1] >= nWires - 1) break;
157 }
158 }
159 // add +-0.5 to account for rounding during preparation
160 relevantID[2 * iSL] -= 0.5;
161 relevantID[2 * iSL + 1] += 0.5;
162 B2DEBUG(28, "SL " << iSL << ": "
163 << relevantID[2 * iSL] << " " << relevantID[2 * iSL + 1]);
164 }
165 return relevantID;
166 }
167
168
169}
static CDCGeometryPar & Instance(const CDCGeometry *=nullptr)
Static method to get a reference to the CDCGeometryPar instance.
TO * getRelatedTo(const std::string &name="", const std::string &namedRelation="") const
Get the object to which this object has a relation.
B2Vector3< double > B2Vector3D
typedef for common usage with double
Definition: B2Vector3.h:516