Belle II Software  release-08-01-10
NeuroTrainer.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 #pragma once
9 #include <iostream>
10 #include <sstream>
11 
12 #include <string>
13 #include <trg/cdc/dataobjects/CDCTriggerMLP.h>
14 #include <framework/datastore/StoreArray.h>
15 #include <framework/datastore/StoreObjPtr.h>
16 #include <trg/cdc/dataobjects/CDCTriggerSegmentHit.h>
17 #include <mdst/dataobjects/MCParticle.h>
18 #include <tracking/dataobjects/RecoTrack.h>
19 #include <trg/cdc/dataobjects/CDCTriggerMLPData.h>
20 #include <trg/cdc/dataobjects/CDCTriggerTrack.h>
21 #include <trg/cdc/dataobjects/CDCTriggerMLPData.h>
22 #include <framework/dataobjects/EventMetaData.h>
23 #include <cdc/geometry/CDCGeometryPar.h>
24 #include <framework/gearbox/Unit.h>
25 #include <cmath>
26 #include <TFile.h>
27 #include <framework/geometry/B2Vector3.h>
28 
29 
30 using namespace Belle2;
31 namespace NeuroTrainer {
32  std::vector<float> getTrainTargets(bool& trainonreco, CDCTriggerTrack* twodtrack, std::string targetcollectionname)
33  {
34  std::vector<float> ret;
35  float phi0Target = 0;
36  float invptTarget = 0;
37  float thetaTarget = 0;
38  float zTarget = 0;
39  float isvalid = 1;
40  if (trainonreco) {
41  RecoTrack* recoTrack =
42  twodtrack->getRelatedTo<RecoTrack>(targetcollectionname);
43  if (!recoTrack) {
44  B2DEBUG(29, "Skipping CDCTriggerTrack without relation to RecoTrack.");
45  isvalid = 0;
46  } else {
47  // a RecoTrack has multiple representations for different particle hypothesis
48  // -> just take the first one that does not give errors.
49  const std::vector<genfit::AbsTrackRep*>& reps = recoTrack->getRepresentations();
50  bool foundValidRep = false;
51  for (unsigned irep = 0; irep < reps.size() && !foundValidRep; ++irep) {
52  if (!recoTrack->wasFitSuccessful(reps[irep]))
53  continue;
54  // get state (position, momentum etc.) from hit closest to IP and
55  // extrapolate to z-axis (may throw an exception -> continue to next representation)
56  try {
58  recoTrack->getMeasuredStateOnPlaneClosestTo(ROOT::Math::XYZVector(0, 0, 0), reps[irep]);
59  reps[irep]->extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000));
60  // flip tracks if necessary, such that trigger tracks and reco tracks
61  // point in the same direction
62  if (state.getMom().Dot(B2Vector3D(twodtrack->getDirection())) < 0) {
63  state.setPosMom(state.getPos(), -state.getMom());
64  state.setChargeSign(-state.getCharge());
65  }
66  // get track parameters
67  phi0Target = state.getMom().Phi();
68  invptTarget = state.getCharge() / state.getMom().Pt();
69  thetaTarget = state.getMom().Theta();
70  zTarget = state.getPos().Z();
71  } catch (...) {
72  continue;
73  }
74  // break loop
75  foundValidRep = true;
76  }
77  }
78  } else {
79  MCParticle* mcTrack =
80  twodtrack->getRelatedTo<MCParticle>(targetcollectionname);
81  if (not mcTrack) {
82  B2DEBUG(29, "Skipping CDCTriggerTrack without relation to MCParticle.");
83  isvalid = 0;
84  } else {
85  phi0Target = mcTrack->getMomentum().Phi();
86  invptTarget = mcTrack->getCharge() / mcTrack->getMomentum().Rho();
87  thetaTarget = mcTrack->getMomentum().Theta();
88  zTarget = mcTrack->getProductionVertex().Z();
89  }
90  }
91  ret.push_back(phi0Target);
92  ret.push_back(invptTarget);
93  ret.push_back(thetaTarget);
94  ret.push_back(zTarget);
95  ret.push_back(isvalid);
96  return ret;
97  }
98  std::vector<float>
99  getRelevantID(CDCTriggerMLPData& trainSet_prepare, double cutsum, double relevantcut)
100  {
101  std::vector<float> relevantID;
102  relevantID.assign(18, 0.);
104  int layerId = 3;
105  for (unsigned iSL = 0; iSL < 9; ++iSL) {
106  int nWires = cdc.nWiresInLayer(layerId);
107  layerId += (iSL > 0 ? 6 : 7);
108  B2DEBUG(28, "SL " << iSL << " (" << nWires << " wires)");
109  // get maximum hit counter
110  unsigned maxCounter = 0;
111  int maxId = 0;
112  unsigned counterSum = 0;
113  for (int iTS = 0; iTS < nWires; ++iTS) {
114  if (trainSet_prepare.getHitCounter(iSL, iTS) > 0)
115  B2DEBUG(28, iTS << " " << trainSet_prepare.getHitCounter(iSL, iTS));
116  if (trainSet_prepare.getHitCounter(iSL, iTS) > maxCounter) {
117  maxCounter = trainSet_prepare.getHitCounter(iSL, iTS);
118  maxId = iTS;
119  }
120  counterSum += trainSet_prepare.getHitCounter(iSL, iTS);
121  }
122  // use maximum as starting range
123  if (maxId > nWires / 2) maxId -= nWires;
124  relevantID[2 * iSL] = maxId;
125  relevantID[2 * iSL + 1] = maxId;
126  if (cutsum) {
127  // add neighboring wire with higher hit count
128  // until sum over unused wires is less than relevantcut * sum over all wires
129  double cut = relevantcut * counterSum;
130  B2DEBUG(28, "Threshold on counterSum: " << cut);
131  unsigned relevantSum = maxCounter;
132  while (counterSum - relevantSum > cut) {
133  int prev = trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL] - 1);
134  int next = trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL + 1] + 1);
135  if (prev > next ||
136  (prev == next &&
137  (relevantID[2 * iSL + 1] - maxId) > (maxId - relevantID[2 * iSL]))) {
138  --relevantID[2 * iSL];
139  relevantSum += prev;
140  if (relevantID[2 * iSL] <= -nWires) break;
141  } else {
142  ++relevantID[2 * iSL + 1];
143  relevantSum += next;
144  if (relevantID[2 * iSL + 1] >= nWires - 1) break;
145  }
146  }
147  } else {
148  // add wires from both sides until hit counter drops below relevantcut * track counter
149  double cut = relevantcut * trainSet_prepare.getTrackCounter();
150  B2DEBUG(28, "Threshold on counter: " << cut);
151  while (trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL] - 1) > cut) {
152  --relevantID[2 * iSL];
153  if (relevantID[2 * iSL] <= -nWires) break;
154  }
155  while (trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL + 1] + 1) > cut) {
156  ++relevantID[2 * iSL + 1];
157  if (relevantID[2 * iSL + 1] >= nWires - 1) break;
158  }
159  }
160  // add +-0.5 to account for rounding during preparation
161  relevantID[2 * iSL] -= 0.5;
162  relevantID[2 * iSL + 1] += 0.5;
163  B2DEBUG(28, "SL " << iSL << ": "
164  << relevantID[2 * iSL] << " " << relevantID[2 * iSL + 1]);
165  }
166  return relevantID;
167  }
168 
169 
170 }
Struct for training data of a single MLP for the neuro trigger.
short getTrackCounter() const
get track counter
unsigned short getHitCounter(unsigned iSL, int iTS) const
Get the hit counter for the given super layer and track segment number in that super layer.
Track created by the CDC trigger.
The Class for CDC Geometry Parameters.
static CDCGeometryPar & Instance(const CDCGeometry *=nullptr)
Static method to get a reference to the CDCGeometryPar instance.
A Class to store the Monte Carlo particle information.
Definition: MCParticle.h:32
This is the Reconstruction Event-Data Model Track.
Definition: RecoTrack.h:79
bool wasFitSuccessful(const genfit::AbsTrackRep *representation=nullptr) const
Returns true if the last fit with the given representation was successful.
Definition: RecoTrack.cc:336
const genfit::MeasuredStateOnPlane & getMeasuredStateOnPlaneClosestTo(const ROOT::Math::XYZVector &closestPoint, const genfit::AbsTrackRep *representation=nullptr)
Return genfit's MeasuredStateOnPlane that is closest to the given point; useful for extrapolation of m...
Definition: RecoTrack.cc:426
const std::vector< genfit::AbsTrackRep * > & getRepresentations() const
Return a list of track representations. You are not allowed to modify or delete them!
Definition: RecoTrack.h:638
StateOnPlane with additional covariance matrix.
Abstract base class for different kinds of events.