13#include <trg/cdc/dataobjects/CDCTriggerMLP.h>
14#include <framework/datastore/StoreArray.h>
15#include <framework/datastore/StoreObjPtr.h>
16#include <trg/cdc/dataobjects/CDCTriggerSegmentHit.h>
17#include <mdst/dataobjects/MCParticle.h>
18#include <tracking/dataobjects/RecoTrack.h>
19#include <trg/cdc/dataobjects/CDCTriggerMLPData.h>
20#include <trg/cdc/dataobjects/CDCTriggerTrack.h>
21#include <trg/cdc/dataobjects/CDCTriggerMLPData.h>
22#include <framework/dataobjects/EventMetaData.h>
23#include <cdc/geometry/CDCGeometryPar.h>
24#include <framework/gearbox/Unit.h>
27#include <framework/geometry/B2Vector3.h>
30namespace Belle2::NeuroTrainer {
/**
 * Build the neural-network training targets for one 2D CDC trigger track.
 *
 * Visible behaviour in this listing:
 *  - Reco branch: follow the RecoTrack related to @p twodtrack under
 *    @p targetcollectionname and loop over its track representations
 *    (checking wasFitSuccessful for each). Take the measured state closest
 *    to the origin and extrapolate it towards the z axis. Then read phi0,
 *    charge/pt, theta and z from that state.
 *  - MC branch: follow the related MCParticle and read the momentum
 *    azimuth, charge / momentum Rho(), the momentum polar angle and the z
 *    coordinate of the production vertex.
 *  - Output layout pushed into ret: {phi0, invpt, theta, z, isvalid}.
 *
 * NOTE(review): this listing is incomplete. The following are not visible
 * here:
 *  - the opening brace;
 *  - the if/else that selects between the two branches;
 *  - the declarations of phi0Target, zTarget, isvalid and mcTrack;
 *  - the null checks guarding the B2DEBUG calls;
 *  - the line that sets foundValidRep;
 *  - the final return.
 * The descriptions that depend on those lines are inferred; confirm them
 * against the full source.
 *
 * @param trainonreco  presumably selects RecoTrack (true) vs MCParticle
 *                     targets. It is taken by non-const reference although
 *                     no visible line writes to it. TODO confirm whether a
 *                     plain bool would do.
 * @param twodtrack    2D trigger track whose relations are followed.
 * @param targetcollectionname  name passed to getRelatedTo<RecoTrack> /
 *                     getRelatedTo<MCParticle> to pick the relation.
 * @return presumably ret, filled as {phi0, invpt, theta, z, isvalid}.
 *         The return statement is not visible here.
 */
31 std::vector<float> getTrainTargets(
bool& trainonreco, CDCTriggerTrack* twodtrack, std::string targetcollectionname)
33 std::vector<float> ret;
35 float invptTarget = 0;
36 float thetaTarget = 0;
// Reco-track targets (presumably the trainonreco branch; the branch
// condition is not visible in this listing).
40 RecoTrack* recoTrack =
41 twodtrack->getRelatedTo<RecoTrack>(targetcollectionname);
// Presumably reached when no RecoTrack is related (the null check is not
// visible here).
43 B2DEBUG(29,
"Skipping CDCTriggerTrack without relation to RecoTrack.");
// Loop over the track representations until foundValidRep becomes true.
// The line that sets it, and what happens when wasFitSuccessful() is
// false, are not visible here.
48 const std::vector<genfit::AbsTrackRep*>& reps = recoTrack->getRepresentations();
49 bool foundValidRep =
false;
50 for (
unsigned irep = 0; irep < reps.size() && !foundValidRep; ++irep) {
51 if (!recoTrack->wasFitSuccessful(reps[irep]))
// Take the measured state closest to the origin. Then extrapolate it to
// the line through (0, 0, -1000) with direction (0, 0, 2000), i.e. the
// z axis. The argument order assumed here is genfit's
// extrapolateToLine(state, linePoint, lineDirection); confirm.
56 genfit::MeasuredStateOnPlane state =
57 recoTrack->getMeasuredStateOnPlaneClosestTo(ROOT::Math::XYZVector(0, 0, 0), reps[irep]);
58 reps[irep]->extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000));
// If the fitted momentum points against the 2D trigger track direction
// (negative dot product), reverse the momentum and flip the charge sign.
// The target then describes the track in the trigger track's direction.
61 if (state.getMom().Dot(
B2Vector3D(twodtrack->getDirection())) < 0) {
62 state.setPosMom(state.getPos(), -state.getMom());
63 state.setChargeSign(-state.getCharge());
// Reco targets:
//  - azimuth and polar angle of the momentum;
//  - charge divided by transverse momentum;
//  - z of the extrapolated position.
66 phi0Target = state.getMom().Phi();
67 invptTarget = state.getCharge() / state.getMom().Pt();
68 thetaTarget = state.getMom().Theta();
69 zTarget = state.getPos().Z();
79 twodtrack->
getRelatedTo<MCParticle>(targetcollectionname);
81 B2DEBUG(29,
"Skipping CDCTriggerTrack without relation to MCParticle.");
// MC targets from the related MCParticle:
//  - momentum azimuth;
//  - charge / momentum Rho() (the transverse component for ROOT 3D vector
//    types; confirm the return type of getMomentum());
//  - momentum polar angle;
//  - z of the production vertex.
// NOTE(review): invptTarget divides by Rho() with no visible guard against
// zero transverse momentum.
84 phi0Target = mcTrack->getMomentum().Phi();
85 invptTarget = mcTrack->getCharge() / mcTrack->getMomentum().Rho();
86 thetaTarget = mcTrack->getMomentum().Theta();
87 zTarget = mcTrack->getProductionVertex().Z();
// Output layout: phi0, invpt, theta, z, isvalid. The declaration of
// isvalid is not visible; if it is a bool, it is stored as 0.f / 1.f.
90 ret.push_back(phi0Target);
91 ret.push_back(invptTarget);
92 ret.push_back(thetaTarget);
93 ret.push_back(zTarget);
94 ret.push_back(isvalid);
/**
 * Determine, per super layer, the range of track-segment IDs considered
 * relevant, based on the hit counters stored in @p trainSet_prepare.
 *
 * relevantID holds 18 floats, two bounds for each of the 9 super layers:
 * [2*iSL] is the lower bound and [2*iSL+1] the upper bound for super
 * layer iSL.
 *
 * For each super layer:
 *  - Scan all IDs 0..nWires-1, recording the largest hit counter and the
 *    total of all counters.
 *  - Wrap the peak ID so that it is at most nWires/2 (it may become
 *    negative) and start the window at [maxId, maxId].
 *  - Grow the window outwards in one of two modes:
 *    * "counterSum" mode: cut = relevantcut * (total counter of the layer).
 *      Grow while counterSum - relevantSum > cut. relevantSum starts at
 *      maxCounter; presumably it accumulates the counts inside the window,
 *      but its updates are not visible here.
 *    * "counter" mode: cut = relevantcut * getTrackCounter(). Grow each
 *      side while the neighbouring ID's hit counter exceeds cut.
 *  - In both modes, the lower bound stops at -nWires and the upper bound
 *    stops at nWires - 1.
 *  - Finally, move both bounds outwards by 0.5.
 *
 * NOTE(review): this listing is incomplete. The following are not visible
 * here:
 *  - the return type and braces;
 *  - the declarations of cdc, layerId and maxId, and the line assigning
 *    maxId;
 *  - the branch that chooses between the two modes (presumably driven by
 *    @p cutsum);
 *  - part of the side-selection condition;
 *  - the relevantSum updates;
 *  - the return statement.
 * Confirm the above against the full source.
 *
 * @param trainSet_prepare  provides getHitCounter(iSL, iTS) and
 *                          getTrackCounter().
 * @param cutsum       presumably selects the counterSum mode (TODO confirm).
 * @param relevantcut  factor multiplied into the threshold in either mode.
 */
98 getRelevantID(CDCTriggerMLPData& trainSet_prepare,
double cutsum,
double relevantcut)
100 std::vector<float> relevantID;
// 9 super layers x {lower, upper} bound.
101 relevantID.assign(18, 0.);
104 for (
unsigned iSL = 0; iSL < 9; ++iSL) {
// nWires is read for the current layerId BEFORE layerId is advanced, so
// the (not visible) initial layerId must refer to the SL0 layer. After
// the read, layerId moves on by 7 (after SL0) or by 6 (after later SLs).
105 int nWires = cdc.nWiresInLayer(layerId);
106 layerId += (iSL > 0 ? 6 : 7);
107 B2DEBUG(28,
"SL " << iSL <<
" (" << nWires <<
" wires)");
// Scan all segment IDs: track the largest hit counter and the total.
// (The line recording maxId for the maximum is not visible here.)
109 unsigned maxCounter = 0;
111 unsigned counterSum = 0;
112 for (
int iTS = 0; iTS < nWires; ++iTS) {
113 if (trainSet_prepare.getHitCounter(iSL, iTS) > 0)
114 B2DEBUG(28, iTS <<
" " << trainSet_prepare.getHitCounter(iSL, iTS));
115 if (trainSet_prepare.getHitCounter(iSL, iTS) > maxCounter) {
116 maxCounter = trainSet_prepare.getHitCounter(iSL, iTS);
119 counterSum += trainSet_prepare.getHitCounter(iSL, iTS);
// Wrap the peak ID into the range <= nWires/2 (it may become negative),
// then start the window at [maxId, maxId].
122 if (maxId > nWires / 2) maxId -= nWires;
123 relevantID[2 * iSL] = maxId;
124 relevantID[2 * iSL + 1] = maxId;
// counterSum mode (the selecting branch condition is not visible).
128 double cut = relevantcut * counterSum;
129 B2DEBUG(28,
"Threshold on counterSum: " << cut);
// NOTE(review): counterSum and relevantSum are unsigned. If relevantSum
// could ever exceed counterSum, the subtraction below wraps around
// instead of going negative. Confirm relevantSum stays <= counterSum.
// NOTE(review): relevantID entries are floats and may be negative. They
// are passed (+/- 1) to getHitCounter; confirm its parameter type and
// that it handles negative / wrapped IDs.
130 unsigned relevantSum = maxCounter;
131 while (counterSum - relevantSum > cut) {
132 int prev = trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL] - 1);
133 int next = trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL + 1] + 1);
136 (relevantID[2 * iSL + 1] - maxId) > (maxId - relevantID[2 * iSL]))) {
// Extend the lower bound. Stop once it has wrapped a full layer.
137 --relevantID[2 * iSL];
139 if (relevantID[2 * iSL] <= -nWires)
break;
// Extend the upper bound (the else-branch that selects this is not
// visible). Stop at nWires - 1.
141 ++relevantID[2 * iSL + 1];
143 if (relevantID[2 * iSL + 1] >= nWires - 1)
break;
// counter mode: the threshold is relative to the number of tracks.
148 double cut = relevantcut * trainSet_prepare.getTrackCounter();
149 B2DEBUG(28,
"Threshold on counter: " << cut);
// Grow the lower bound while the next lower ID is above threshold.
150 while (trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL] - 1) > cut) {
151 --relevantID[2 * iSL];
152 if (relevantID[2 * iSL] <= -nWires)
break;
// Grow the upper bound while the next higher ID is above threshold.
154 while (trainSet_prepare.getHitCounter(iSL, relevantID[2 * iSL + 1] + 1) > cut) {
155 ++relevantID[2 * iSL + 1];
156 if (relevantID[2 * iSL + 1] >= nWires - 1)
break;
// Widen both bounds by half an ID. Presumably this places each bound
// between integer IDs; confirm.
160 relevantID[2 * iSL] -= 0.5;
161 relevantID[2 * iSL + 1] += 0.5;
162 B2DEBUG(28,
"SL " << iSL <<
": "
163 << relevantID[2 * iSL] <<
" " << relevantID[2 * iSL + 1]);
static CDCGeometryPar & Instance(const CDCGeometry *=nullptr)
Static method to get a reference to the CDCGeometryPar instance.
TO * getRelatedTo(const std::string &name="", const std::string &namedRelation="") const
Get the object to which this object has a relation.
B2Vector3< double > B2Vector3D
typedef for common usage with double