1 #include "trg/cdc/modules/neurotrigger/CDCTriggerNeuroTrainerModule.h"
3 #include <parallel_fann.hpp>
8 #include <framework/datastore/StoreArray.h>
9 #include <mdst/dataobjects/MCParticle.h>
10 #include <tracking/dataobjects/RecoTrack.h>
11 #include <trg/cdc/dataobjects/CDCTriggerSegmentHit.h>
12 #include <trg/cdc/dataobjects/CDCTriggerTrack.h>
13 #include <framework/datastore/StoreObjPtr.h>
14 #include <framework/dataobjects/EventMetaData.h>
15 #include <framework/core/ModuleParam.templateDetails.h>
17 #include <cdc/geometry/CDCGeometryPar.h>
18 #include <framework/gearbox/Unit.h>
34 "The NeuroTriggerTrainer module of the CDC trigger.\n"
35 "Takes track segments and 2D track estimates to prepare input data\n"
36 "for the training of a neural network.\n"
37 "Networks are trained after the event loop and saved.\n\n"
38 "Data preparation is done in two steps:\n"
39 "1. The MLP uses hits from a limited range around the 2D track. "
40 "To find this range, a histogram with the distance of hits to the 2D track "
41 "is prepared. The relevant ID range is determined by a threshold on "
42 "the hit counters or on the sum of the hit counters over the relevant range.\n"
43 "2. Input data is calculated from the hits, the 2D tracks and the ID ranges. "
44 "Target data is collected from a MCParticle or RecoTrack related to the 2D track."
// Registration of the steering parameters. Every addParam call binds a
// parameter name to a member variable and supplies its description string.
// NOTE(review): some calls in this excerpt end on a trailing comma -- the line
// carrying their default value is not visible here.
// Names of the input objects: segment hits, event time and 2D input tracks.
47 addParam(
"hitCollectionName", m_hitCollectionName,
48 "Name of the input StoreArray of CDCTriggerSegmentHits.",
50 addParam(
"EventTimeName", m_EventTimeName,
51 "Name of the event time object.",
53 addParam(
"inputCollectionName", m_inputCollectionName,
54 "Name of the StoreArray holding the 2D input tracks.",
55 string(
"TRGCDC2DFinderTracks"));
// Choice of the target source and name of the target collection
// (the collection name defaults to "MCParticles").
56 addParam(
"trainOnRecoTracks", m_trainOnRecoTracks,
57 "If true, use RecoTracks as targets instead of MCParticles.",
59 addParam(
"targetCollectionName", m_targetCollectionName,
60 "Name of the MCParticle/RecoTrack collection used as target values.",
61 string(
"MCParticles"));
// Output locations. Note that "filename" and "trainFilename" share the same
// default file name, "NeuroTrigger.root".
62 addParam(
"filename", m_filename,
63 "Name of the root file where the NeuroTrigger parameters will be saved.",
64 string(
"NeuroTrigger.root"));
65 addParam(
"trainFilename", m_trainFilename,
66 "Name of the root file where the generated training samples will be saved.",
67 string(
"NeuroTrigger.root"));
68 addParam(
"logFilename", m_logFilename,
69 "Base name of the text files where the training logs will be saved "
70 "(two for each sector, named logFilename_BestRun_i.log "
71 "and logFilename_AllOptima_i.log).",
72 string(
"NeuroTrigger"));
73 addParam(
"arrayname", m_arrayname,
74 "Name of the TObjArray to hold the NeuroTrigger parameters.",
76 addParam(
"trainArrayname", m_trainArrayname,
77 "Name of the TObjArray to hold the training samples.",
// Debug output is on by default, loading of previously saved data is off.
79 addParam(
"saveDebug", m_saveDebug,
80 "If true, save parameter distribution of training data "
81 "in train file and training curve in log file.",
true);
82 addParam(
"load", m_load,
83 "Switch to load saved parameters if existing. "
84 "Take care not to duplicate training sets!",
false);
// The following parameters are stored directly in the m_parameters struct.
// Each one passes the current value of the struct member as its default.
86 addParam(
"nMLP", m_parameters.nMLP,
87 "Number of expert MLPs.", m_parameters.nMLP);
88 addParam(
"nHidden", m_parameters.nHidden,
89 "Number of nodes in each hidden layer for all networks "
90 "or factor to multiply with number of inputs (1 list or nMLP lists). "
91 "The number of layers is derived from the shape.", m_parameters.nHidden);
92 addParam(
"multiplyHidden", m_parameters.multiplyHidden,
93 "If true, multiply nHidden with number of input nodes.",
94 m_parameters.multiplyHidden);
// Selection of the network outputs and their scaling.
95 addParam(
"targetZ", m_parameters.targetZ,
96 "Train one output of MLP to give z.", m_parameters.targetZ);
97 addParam(
"targetTheta", m_parameters.targetTheta,
98 "Train one output of MLP to give theta.", m_parameters.targetTheta);
99 addParam(
"outputScale", m_parameters.outputScale,
100 "Output scale for all networks (1 value list or nMLP value lists). "
101 "Output[i] of the MLP is scaled from [-1, 1] "
102 "to [outputScale[2*i], outputScale[2*i+1]]. "
103 "(units: z[cm] / theta[degree])", m_parameters.outputScale);
// Sector definition: phi, charge/pt and theta ranges covered by the experts.
104 addParam(
"phiRange", m_parameters.phiRange,
105 "Phi region in degree for which experts are trained. "
106 "1 value pair, nMLP value pairs or nPhi value pairs "
107 "with nPhi * nPt * nTheta * nPattern = nMLP.", m_parameters.phiRange);
108 addParam(
"invptRange", m_parameters.invptRange,
109 "Charge / Pt region in 1/GeV for which experts are trained. "
110 "1 value pair, nMLP value pairs or nPt value pairs "
111 "with nPhi * nPt * nTheta * nPattern = nMLP.", m_parameters.invptRange);
112 addParam(
"thetaRange", m_parameters.thetaRange,
113 "Theta region in degree for which experts are trained. "
114 "1 value pair, nMLP value pairs or nTheta value pairs "
115 "with nPhi * nPt * nTheta * nPattern = nMLP.", m_parameters.thetaRange);
// Range from which training events are taken for phi.
116 addParam(
"phiRangeTrain", m_parameters.phiRangeTrain,
117 "Phi region in degree from which training events are taken. "
118 "Can be larger than phiRange to avoid edge effect.", m_parameters.phiRangeTrain);
// Ranges from which training events are taken for charge/pt and theta.
// Each description refers to its own expert range (invptRange / thetaRange);
// the previous text named phiRange in both, copied from phiRangeTrain.
// The default of each parameter is the current value held in m_parameters.
119 addParam(
"invptRangeTrain", m_parameters.invptRangeTrain,
120 "Charge / Pt region in 1/GeV from which training events are taken. "
121 "Can be larger than invptRange to avoid edge effect.", m_parameters.invptRangeTrain);
122 addParam(
"thetaRangeTrain", m_parameters.thetaRangeTrain,
123 "Theta region in degree from which training events are taken. "
124 "Can be larger than thetaRange to avoid edge effect.", m_parameters.thetaRangeTrain);
// Further parameters stored in m_parameters (hit selection per super layer,
// drift time scaling and event time option). As above, the default of each
// is the current value of the struct member.
125 addParam(
"maxHitsPerSL", m_parameters.maxHitsPerSL,
126 "Maximum number of hits in a single SL. "
127 "1 value or same as SLpattern.", m_parameters.maxHitsPerSL);
128 addParam(
"SLpattern", m_parameters.SLpattern,
129 "Super layer pattern for which experts are trained. "
130 "1 value, nMLP values or nPattern values "
131 "with nPhi * nPt * nTheta * nPattern = nMLP.", m_parameters.SLpattern);
132 addParam(
"SLpatternMask", m_parameters.SLpatternMask,
133 "Super layer pattern mask for which experts are trained. "
134 "1 value or same as SLpattern.", m_parameters.SLpatternMask);
135 addParam(
"tMax", m_parameters.tMax,
136 "Maximal drift time (for scaling, unit: trigger timing bins).", m_parameters.tMax);
137 addParam(
"et_option", m_parameters.et_option,
138 "option on how to obtain the event time. Possibilities are: "
139 "'etf_only', 'fastestpriority', 'zero', 'etf_or_fastestpriority', 'etf_or_zero', 'etf_or_fastest2d', 'fastest2d'.",
140 m_parameters.et_option);
141 addParam(
"T0fromHits", m_parameters.T0fromHits,
142 "Deprecated, kept for backward compatibility. If true, the event time is "
143 "determined from all relevant hits in a sector, if there is no valid event "
144 "time from the event time finder. If false, no drift times are used if "
145 "there is no valid event time.",
146 m_parameters.T0fromHits);
// From here on the parameters are plain module members again.
147 addParam(
"selectSectorByMC", m_selectSectorByMC,
148 "If true, track parameters for sector selection are taken "
149 "from MCParticle instead of CDCTriggerTrack.",
false);
// Preparation of the relevant ID ranges (first step of the data preparation).
151 addParam(
"nTrainPrepare", m_nTrainPrepare,
152 "Number of samples for preparation of relevant ID ranges "
153 "(0: use default ranges).", 1000);
// NOTE(review): the default-value line of "IDranges" is not visible here.
154 addParam(
"IDranges", m_IDranges,
155 "If list is not empty, it will replace the default ranges. "
156 "1 list or nMLP lists. Set nTrainPrepare to 0 if you use this option.",
158 addParam(
"relevantCut", m_relevantCut,
159 "Cut for preparation of relevant ID ranges.", 0.02);
160 addParam(
"cutSum", m_cutSum,
161 "If true, relevantCut is applied to the sum over hit counters, "
162 "otherwise directly on the hit counters.",
false);
// Sample sizes. nTrainMin and nTrainMax share the same default (10.).
163 addParam(
"nTrainMin", m_nTrainMin,
164 "Minimal number of training samples "
165 "or factor to multiply with number of weights. "
166 "If the minimal number of samples is not reached, "
167 "all samples are saved but no training is started.", 10.);
168 addParam(
"nTrainMax", m_nTrainMax,
169 "Maximal number of training samples "
170 "or factor to multiply with number of weights. "
171 "When the maximal number of samples is reached, "
172 "no further samples are added.", 10.);
// NOTE(review): the default-value line of "multiplyNTrain" is not visible here.
173 addParam(
"multiplyNTrain", m_multiplyNTrain,
174 "If true, multiply nTrainMin and nTrainMax with number of weights.",
176 addParam(
"nValid", m_nValid,
177 "Number of validation samples for training.", 1000);
178 addParam(
"nTest", m_nTest,
179 "Number of test samples to get resolution after training.", 5000);
180 addParam(
"stopLoop", m_stopLoop,
181 "If true, stop event loop when maximal number of samples "
182 "is reached for all sectors.",
true);
183 addParam(
"rescaleTarget", m_rescaleTarget,
184 "If true, set target values > outputScale to 1, "
185 "else skip them.",
true);
// Control of the training itself.
// NOTE(review): the default-value line of "wMax" is not visible here.
187 addParam(
"wMax", m_wMax,
188 "Weights are limited to [-wMax, wMax] after each training epoch "
189 "(for convenience of the FPGA implementation).",
191 addParam(
"nThreads", m_nThreads,
192 "Number of threads for parallel training.", 1);
193 addParam(
"checkInterval", m_checkInterval,
194 "Training is stopped if validation error is higher than "
195 "checkInterval epochs ago, i.e. either the validation error is increasing "
196 "or the gain is less than the fluctuations.", 500);
197 addParam(
"maxEpochs", m_maxEpochs,
198 "Maximum number of training epochs.", 10000);
199 addParam(
"repeatTrain", m_repeatTrain,
200 "If >1, training is repeated several times with different start weights. "
201 "The weights which give the best resolution on the test samples are kept.", 1);
// NOTE(review): the default-value line of "NeuroTrackInputMode" is not visible here.
202 addParam(
"NeuroTrackInputMode", m_neuroTrackInputMode,
203 "When using real tracks, use neurotracks instead of 2dtracks as input to the neurotrigger",
// NOTE(review): fragment of the module set-up code. The enclosing function
// header, several declarations (e.g. `targets`, `cdc`, `layerId`) and the
// closing braces are on lines that are not visible in this excerpt.
// The collection of input tracks is mandatory.
212 m_tracks.isRequired(m_inputCollectionName);
// In both branches the target collection is required under the same name;
// presumably only the type of `targets` differs (declarations not visible).
213 if (m_trainOnRecoTracks) {
215 targets.isRequired(m_targetCollectionName);
218 targets.isRequired(m_targetCollectionName);
// Tail of a condition whose beginning is not visible: when loading the
// training data or loading the NeuroTrigger yields false, the NeuroTrigger is
// initialized from m_parameters instead.
222 !loadTraindata(m_trainFilename, m_trainArrayname) ||
223 !m_NeuroTrigger.load(m_filename, m_arrayname)) {
224 m_NeuroTrigger.initialize(m_parameters);
// For every sector, add one hit counter per super layer (9 iterations), sized
// by the number of wires in layer `layerId`. layerId is advanced by 7 after
// the first super layer and by 6 after each later one.
227 for (
unsigned iMLP = 0; iMLP < m_NeuroTrigger.nSectors(); ++iMLP) {
230 for (
int iSL = 0; iSL < 9; ++iSL) {
231 m_trainSets[iMLP].addCounters(cdc.nWiresInLayer(layerId));
232 layerId += (iSL > 0 ? 6 : 7);
236 m_NeuroTrigger.initializeCollections(m_hitCollectionName, m_EventTimeName, m_parameters.et_option);
// Consistency check: there must be exactly one training set per sector.
238 if (m_NeuroTrigger.nSectors() != m_trainSets.size())
239 B2ERROR(
"Number of training sets (" << m_trainSets.size() <<
") should match " <<
240 "number of sectors (" << m_NeuroTrigger.nSectors() <<
")");
// nTrainMin may not exceed nTrainMax; it is lowered to nTrainMax with a warning.
241 if (m_nTrainMin > m_nTrainMax) {
242 m_nTrainMin = m_nTrainMax;
243 B2WARNING(
"nTrainMin set to " << m_nTrainMin <<
" (was larger than nTrainMax)");
// Optional ID ranges given by the user: either one list shared by all sectors
// or one list per sector. Each list must hold exactly 18 values, otherwise an
// error is reported for that sector.
246 if (m_IDranges.size() > 0) {
247 if (m_IDranges.size() == 1 || m_IDranges.size() == m_NeuroTrigger.nSectors()) {
248 B2DEBUG(50,
"Setting relevant ID ranges from parameters.");
249 for (
unsigned isector = 0; isector < m_NeuroTrigger.nSectors(); ++isector) {
// With a single list, index 0 is used for every sector.
250 unsigned iranges = (m_IDranges.size() == 1) ? 0 : isector;
251 if (m_IDranges[iranges].size() == 18)
252 m_NeuroTrigger[isector].relevantID = m_IDranges[iranges];
254 B2ERROR(
"IDranges must contain 18 values (sector " << isector
255 <<
" has " << m_IDranges[iranges].size() <<
")");
// Ranges given by hand are overwritten again if a preparation phase is enabled.
257 if (m_nTrainPrepare > 0)
258 B2WARNING(
"Given ID ranges will be replaced during training. "
259 "Set nTrainPrepare = 0 if you want to give ID ranges by hand.");
261 B2ERROR(
"Number of IDranges should be 0, 1, or " << m_NeuroTrigger.nSectors());
// Book the debug histograms, one set per sector, named "<quantity><sector>".
// The phi histograms have 100 bins from -2*pi to 2*pi; the binning of the
// other histograms and some push_back lines are not visible in this excerpt.
267 for (
unsigned iMLP = 0; iMLP < m_NeuroTrigger.nSectors(); ++iMLP) {
268 phiHistsMC.push_back(
269 new TH1D((
"phiMC" + to_string(iMLP)).c_str(),
270 (
"MC phi in sector " + to_string(iMLP)).c_str(),
271 100, -2 * M_PI, 2 * M_PI));
273 new TH1D((
"ptMC" + to_string(iMLP)).c_str(),
274 (
"MC charge / pt in sector " + to_string(iMLP)).c_str(),
276 thetaHistsMC.push_back(
277 new TH1D((
"thetaMC" + to_string(iMLP)).c_str(),
278 (
"MC theta in sector " + to_string(iMLP)).c_str(),
281 new TH1D((
"zMC" + to_string(iMLP)).c_str(),
282 (
"MC z in sector " + to_string(iMLP)).c_str(),
284 phiHists2D.push_back(
285 new TH1D((
"phi2D" + to_string(iMLP)).c_str(),
286 (
"2D phi in sector " + to_string(iMLP)).c_str(),
287 100, -2 * M_PI, 2 * M_PI));
289 new TH1D((
"pt2D" + to_string(iMLP)).c_str(),
290 (
"2D charge / pt in sector " + to_string(iMLP)).c_str(),
// NOTE(review): fragment of the per-event sample collection. The function
// header, several declarations (e.g. `zTarget`, `reps`, `state`, `hit`) and
// most closing braces / `continue` statements are not visible in this excerpt.
// Loop over all input tracks of the event.
299 for (
int itrack = 0; itrack < m_tracks.getEntries(); ++itrack) {
// Target values, filled below from a RecoTrack or an MCParticle.
302 float phi0Target = 0;
303 float invptTarget = 0;
304 float thetaTarget = 0;
306 if (m_trainOnRecoTracks) {
310 B2DEBUG(150,
"Skipping CDCTriggerTrack without relation to RecoTrack.");
// Try the track representations one after the other until one gives a valid
// result (the loop ends as soon as foundValidRep is set).
316 bool foundValidRep =
false;
317 for (
unsigned irep = 0; irep < reps.size() && !foundValidRep; ++irep) {
// Extrapolate the state to the line defined by the two vectors
// (presumably the z axis -- confirm against the extrapolateToLine API).
325 reps[irep]->extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000));
// If the momentum points against the direction of the trigger track,
// flip the momentum and the charge sign of the state.
328 if (state.getMom().Dot(m_tracks[itrack]->getDirection()) < 0) {
329 state.setPosMom(state.getPos(), -state.getMom());
330 state.setChargeSign(-state.getCharge());
// Read the target values from the extrapolated state.
333 phi0Target = state.getMom().Phi();
334 invptTarget = state.getCharge() / state.getMom().Pt();
335 thetaTarget = state.getMom().Theta();
336 zTarget = state.getPos().Z();
341 foundValidRep =
true;
343 if (!foundValidRep) {
344 B2DEBUG(150,
"No valid representation found for RecoTrack, skipping.");
351 B2DEBUG(150,
"Skipping CDCTriggerTrack without relation to MCParticle.");
// Pass the current track to the NeuroTrigger before hits are evaluated.
361 m_NeuroTrigger.updateTrack(*m_tracks[itrack]);
// Track parameters used for the sector selection, taken from the 2D track.
// theta is derived from cot(theta) via atan2(1, cotTheta).
364 float phi0 = m_tracks[itrack]->getPhi0();
365 float invpt = m_tracks[itrack]->getKappa(1.5);
366 float theta = atan2(1., m_tracks[itrack]->getCotTheta());
// NOTE(review): body of this branch is not visible; per the parameter
// description the values are presumably replaced by the MC values.
367 if (m_selectSectorByMC) {
// Find all sectors (MLPs) matching the track; skip the track if there is none.
372 vector<int> sectors = m_NeuroTrigger.selectMLPs(phi0, invpt, theta);
373 if (sectors.size() == 0)
continue;
// Unscaled target vector: z first (if enabled), then theta (if enabled).
375 vector<float> targetRaw = {};
376 if (m_parameters.targetZ)
377 targetRaw.push_back(zTarget);
378 if (m_parameters.targetTheta)
379 targetRaw.push_back(thetaTarget);
380 for (
unsigned i = 0; i < sectors.size(); ++i) {
381 int isector = sectors[i];
382 vector<float> target = m_NeuroTrigger[isector].scaleTarget(targetRaw);
// Scaled targets with absolute value above 1 are set to +1 or -1 (division by
// the own absolute value). outOfRange is presumably set on a line that is not
// visible here.
384 bool outOfRange =
false;
385 for (
unsigned itarget = 0; itarget < target.size(); ++itarget) {
386 if (fabs(target[itarget]) > 1.) {
388 target[itarget] /= fabs(target[itarget]);
// Without rescaling, samples with out-of-range targets are skipped.
391 if (!m_rescaleTarget && outOfRange)
continue;
// Preparation phase: as long as fewer than m_nTrainPrepare tracks were counted
// for this sector, only the hit counters are filled (with the rounded relative
// ID of each hit) and the track is counted.
393 if (m_nTrainPrepare > 0 &&
394 m_trainSets[isector].getTrackCounter() < m_nTrainPrepare) {
398 if (m_trainOnRecoTracks) {
404 double relId = m_NeuroTrigger.getRelId(hit);
405 m_trainSets[isector].
addHit(hit.getISuperLayer(), round(relId));
413 double relId = m_NeuroTrigger.getRelId(hit);
414 m_trainSets[isector].addHit(hit.getISuperLayer(), round(relId));
417 m_trainSets[isector].countTrack();
// Once enough tracks were counted, derive the relevant ID ranges.
419 if (m_trainSets[isector].getTrackCounter() >= m_nTrainPrepare) {
420 updateRelevantID(isector);
// Upper limit on the number of samples: either m_nTrainMax directly or
// m_nTrainMax times the number of weights of the sector's network.
424 float nTrainMax = m_multiplyNTrain ? m_nTrainMax * m_NeuroTrigger[isector].nWeights() : m_nTrainMax;
// NOTE(review): body not visible; presumably no further samples are added.
425 if (m_trainSets[isector].nSamples() > (nTrainMax + m_nValid + m_nTest)) {
429 m_NeuroTrigger.getEventTime(isector, *m_tracks[itrack], m_parameters.et_option, m_neuroTrackInputMode);
// Compare the hit pattern of the track with the pattern required by the sector:
// if the sector pattern is non-zero, all of its bits must be set in hitPattern.
431 unsigned long hitPattern = m_NeuroTrigger.getInputPattern(isector, *m_tracks[itrack], m_neuroTrackInputMode);
432 unsigned long sectorPattern = m_NeuroTrigger[isector].getSLpattern();
433 B2DEBUG(250,
"hitPattern " << hitPattern <<
" sectorPattern " << sectorPattern);
434 if (sectorPattern > 0 && (sectorPattern & hitPattern) != sectorPattern) {
435 B2DEBUG(250,
"hitPattern not matching " << (sectorPattern & hitPattern));
// Select the hits for the input vector; the selection method depends on
// m_neuroTrackInputMode.
439 vector<unsigned> hitIds;
440 if (m_neuroTrackInputMode) {
441 hitIds = m_NeuroTrigger.selectHitsHWSim(isector, *m_tracks[itrack]);
443 hitIds = m_NeuroTrigger.selectHits(isector, *m_tracks[itrack]);
// Store the sample (input vector and scaled target) for this sector.
445 m_trainSets[isector].addSample(m_NeuroTrigger.getInputVector(isector, hitIds), target);
// Fill the debug histograms with the target values and the 2D track values.
447 phiHistsMC[isector]->Fill(phi0Target);
448 ptHistsMC[isector]->Fill(invptTarget);
449 thetaHistsMC[isector]->Fill(thetaTarget);
450 zHistsMC[isector]->Fill(zTarget);
451 phiHists2D[isector]->Fill(m_tracks[itrack]->getPhi0());
452 ptHists2D[isector]->Fill(m_tracks[itrack]->getKappa(1.5));
// Progress message every 1000 samples.
454 if (m_trainSets[isector].nSamples() % 1000 == 0) {
455 B2DEBUG(50, m_trainSets[isector].nSamples() <<
" samples collected for sector " << isector);
// Check for every sector whether the requested number of samples
// (nTrainMax + m_nValid + m_nTest) has been reached.
// NOTE(review): the handling of an incomplete sector and the surrounding
// condition are not visible; presumably the loop is only stopped when no
// sector is below the limit and m_stopLoop is set.
463 for (
unsigned isector = 0; isector < m_trainSets.size(); ++isector) {
464 float nTrainMax = m_multiplyNTrain ? m_nTrainMax * m_NeuroTrigger[isector].nWeights() : m_nTrainMax;
465 if (m_trainSets[isector].nSamples() < (nTrainMax + m_nValid + m_nTest)) {
471 B2INFO(
"Training sample preparation for NeuroTrigger finished, stopping event loop.");
// Signal the end of data through the event meta data object.
473 eventMetaData->setEndOfData();
// NOTE(review): fragment of the end-of-job step. The function header, the
// statements controlled by the `if` conditions below, the call that starts the
// training and the closing braces are not visible in this excerpt.
// Save all collected training samples first.
482 saveTraindata(m_trainFilename, m_trainArrayname);
484 for (
unsigned isector = 0; isector < m_NeuroTrigger.nSectors(); ++isector) {
// NOTE(review): the statement belonging to this `if` is on a line that is not
// visible here; presumably sectors that are already trained are skipped.
486 if (m_NeuroTrigger[isector].isTrained())
// Minimal number of samples: either m_nTrainMin directly or m_nTrainMin times
// the number of weights of the sector's network.
488 float nTrainMin = m_multiplyNTrain ? m_nTrainMin * m_NeuroTrigger[isector].nWeights() : m_nTrainMin;
// Warn when fewer samples than nTrainMin + m_nValid + m_nTest are available.
489 if (m_trainSets[isector].nSamples() < (nTrainMin + m_nValid + m_nTest)) {
490 B2WARNING(
"Not enough training samples for sector " << isector <<
" (" << (nTrainMin + m_nValid + m_nTest)
491 <<
" requested, " << m_trainSets[isector].nSamples() <<
" found)");
// Mark the sector as trained.
495 m_NeuroTrigger[isector].trained =
true;
// Look up the configured phi / invpt / theta ranges belonging to this sector.
// indices[0], [1], [2] select the phi, invpt and theta range respectively.
497 vector<unsigned> indices = m_NeuroTrigger.getRangeIndices(m_parameters, isector);
498 vector<float> phiRange = m_parameters.phiRange[indices[0]];
499 vector<float> invptRange = m_parameters.invptRange[indices[1]];
500 vector<float> thetaRange = m_parameters.thetaRange[indices[2]];
// NOTE(review): lines between the lookup and the assignment are not visible;
// the ranges may be modified there (e.g. unit conversion) -- confirm.
506 m_NeuroTrigger[isector].phiRange = phiRange;
507 m_NeuroTrigger[isector].invptRange = invptRange;
508 m_NeuroTrigger[isector].thetaRange = thetaRange;
// Write the NeuroTrigger to file.
510 m_NeuroTrigger.save(m_filename, m_arrayname);
// NOTE(review): fragment of the routine that derives the relevant ID ranges of
// one sector from the hit counters. The function header, the declarations of
// `cdc`, `layerId` and `maxId`, the branch on the cut mode and the closing
// braces are not visible in this excerpt.
517 B2DEBUG(50,
"Setting relevant ID ranges for sector " << isector);
// 18 entries: two per super layer, at index 2 * iSL (lower end) and
// 2 * iSL + 1 (upper end).
518 vector<float> relevantID;
519 relevantID.assign(18, 0.);
522 for (
unsigned iSL = 0; iSL < 9; ++iSL) {
// Number of wires of the current layer; layerId is advanced by 7 after the
// first super layer and by 6 after each later one.
523 int nWires = cdc.nWiresInLayer(layerId);
524 layerId += (iSL > 0 ? 6 : 7);
525 B2DEBUG(90,
"SL " << iSL <<
" (" << nWires <<
" wires)");
// Scan all counters of this super layer: remember the largest counter and
// accumulate the sum of all counters. The update of maxId is presumably on a
// line that is not visible here.
527 unsigned maxCounter = 0;
529 unsigned counterSum = 0;
530 for (
int iTS = 0; iTS < nWires; ++iTS) {
531 if (m_trainSets[isector].getHitCounter(iSL, iTS) > 0)
532 B2DEBUG(90, iTS <<
" " << m_trainSets[isector].getHitCounter(iSL, iTS));
533 if (m_trainSets[isector].getHitCounter(iSL, iTS) > maxCounter) {
534 maxCounter = m_trainSets[isector].getHitCounter(iSL, iTS);
537 counterSum += m_trainSets[isector].getHitCounter(iSL, iTS);
// IDs in the upper half are mapped to negative values (maxId - nWires).
540 if (maxId > nWires / 2) maxId -= nWires;
// The range starts as the single ID with the largest counter.
541 relevantID[2 * iSL] = maxId;
542 relevantID[2 * iSL + 1] = maxId;
// Variant with a threshold on the counter sum (presumably the m_cutSum branch):
// the range is extended while the sum of counters outside of it exceeds
// m_relevantCut * counterSum.
546 double cut = m_relevantCut * counterSum;
547 B2DEBUG(50,
"Threshold on counterSum: " << cut);
548 unsigned relevantSum = maxCounter;
549 while (counterSum - relevantSum > cut) {
// Counters of the neighbours just below and just above the current range.
// NOTE(review): the first part of the condition deciding the direction of the
// extension and the updates of relevantSum are not visible in this excerpt.
550 int prev = m_trainSets[isector].getHitCounter(iSL, relevantID[2 * iSL] - 1);
551 int next = m_trainSets[isector].getHitCounter(iSL, relevantID[2 * iSL + 1] + 1);
554 (relevantID[2 * iSL + 1] - maxId) > (maxId - relevantID[2 * iSL]))) {
555 --relevantID[2 * iSL];
// Stop when the lower end reaches -nWires.
557 if (relevantID[2 * iSL] <= -nWires)
break;
559 ++relevantID[2 * iSL + 1];
// Stop when the upper end reaches nWires - 1.
561 if (relevantID[2 * iSL + 1] >= nWires - 1)
break;
// Variant with a threshold on the single counters: the cut is m_relevantCut
// times the number of counted tracks, and the range grows in each direction
// while the neighbouring counter is above the cut.
566 double cut = m_relevantCut * m_trainSets[isector].getTrackCounter();
567 B2DEBUG(50,
"Threshold on counter: " << cut);
568 while (m_trainSets[isector].getHitCounter(iSL, relevantID[2 * iSL] - 1) > cut) {
569 --relevantID[2 * iSL];
570 if (relevantID[2 * iSL] <= -nWires)
break;
572 while (m_trainSets[isector].getHitCounter(iSL, relevantID[2 * iSL + 1] + 1) > cut) {
573 ++relevantID[2 * iSL + 1];
574 if (relevantID[2 * iSL + 1] >= nWires - 1)
break;
// Widen the range by half an ID on both sides.
578 relevantID[2 * iSL] -= 0.5;
579 relevantID[2 * iSL + 1] += 0.5;
580 B2DEBUG(50,
"SL " << iSL <<
": "
581 << relevantID[2 * iSL] <<
" " << relevantID[2 * iSL + 1]);
// Store the ranges in the MLP of this sector.
583 m_NeuroTrigger[isector].relevantID = relevantID;
// NOTE(review): fragment of the training routine for one sector. The function
// header, the declarations of `currentData`, `bestEpoch` and `breakEpoch`, the
// closing braces and the switch between the two alternatives below
// (presumably a preprocessor conditional on OpenMP support) are not visible
// in this excerpt.
590 B2INFO(
"Training network for sector " << isector <<
" with OpenMP");
592 B2INFO(
"Training network for sector " << isector <<
" without OpenMP");
// Create a FANN network with the same layer structure as the sector's MLP.
// NOTE(review): nNodes is allocated with new[]; its release is not visible in
// this excerpt -- confirm that it is freed with delete[].
595 unsigned nLayers = m_NeuroTrigger[isector].nLayers();
596 unsigned* nNodes =
new unsigned[nLayers];
597 for (
unsigned il = 0; il < nLayers; ++il) {
598 nNodes[il] = m_NeuroTrigger[isector].nNodesLayer(il);
600 struct fann* ann = fann_create_standard_array(nLayers, nNodes);
// The samples are split by index: the first nTrain samples are used for
// training, the next m_nValid for validation and the rest for testing.
// Note that nTrain is unsigned, so the caller must ensure that at least
// m_nValid + m_nTest samples exist.
604 unsigned nTrain = m_trainSets[isector].
nSamples() - m_nValid - m_nTest;
// Copy the training samples into the FANN data structure
// (first layer size = number of inputs, last layer size = number of outputs).
605 struct fann_train_data* train_data =
606 fann_create_train(nTrain, nNodes[0], nNodes[nLayers - 1]);
607 for (
unsigned i = 0; i < nTrain; ++i) {
608 vector<float> input = currentData.
getInput(i);
609 for (
unsigned j = 0; j < input.size(); ++j) {
610 train_data->input[i][j] = input[j];
612 vector<float> target = currentData.
getTarget(i);
613 for (
unsigned j = 0; j < target.size(); ++j) {
614 train_data->output[i][j] = target[j];
// Copy the validation samples; indices are shifted by nTrain.
618 struct fann_train_data* valid_data =
619 fann_create_train(m_nValid, nNodes[0], nNodes[nLayers - 1]);
620 for (
unsigned i = nTrain; i < nTrain + m_nValid; ++i) {
621 vector<float> input = currentData.
getInput(i);
622 for (
unsigned j = 0; j < input.size(); ++j) {
623 valid_data->input[i - nTrain][j] = input[j];
625 vector<float> target = currentData.
getTarget(i);
626 for (
unsigned j = 0; j < target.size(); ++j) {
627 valid_data->output[i - nTrain][j] = target[j];
// Symmetric sigmoid activation in all layers, RPROP as training algorithm.
631 fann_set_activation_function_hidden(ann, FANN_SIGMOID_SYMMETRIC);
632 fann_set_activation_function_output(ann, FANN_SIGMOID_SYMMETRIC);
633 fann_set_training_algorithm(ann, FANN_TRAIN_RPROP);
// Best test RMS over all runs (999. serves as start value).
634 double bestRMS = 999.;
// Training / validation error curves of the best run.
636 vector<double> bestTrainLog = {};
637 vector<double> bestValidLog = {};
// Training / validation error at the optimum of every run.
639 vector<double> trainOptLog = {};
640 vector<double> validOptLog = {};
// Repeat the training m_repeatTrain times with new random start weights.
642 for (
int irun = 0; irun < m_repeatTrain; ++irun) {
643 double bestValid = 999.;
644 vector<double> trainLog = {};
645 vector<double> validLog = {};
646 trainLog.assign(m_maxEpochs, 0.);
647 validLog.assign(m_maxEpochs, 0.);
// Weights at the epoch with the lowest validation error of this run.
650 vector<fann_type> bestWeights = {};
651 bestWeights.assign(m_NeuroTrigger[isector].nWeights(), 0.);
652 fann_randomize_weights(ann, -0.1, 0.1);
// Epochs are counted from 1; the logs are indexed with epoch - 1.
654 for (
int epoch = 1; epoch <= m_maxEpochs; ++epoch) {
// Two alternative calls for one training epoch (parallel / serial).
656 double mse = parallel_fann::train_epoch_irpropm_parallel(ann, train_data, m_nThreads);
658 double mse = fann_train_epoch(ann, train_data);
660 trainLog[epoch - 1] = mse;
// Clamp all weights to [-m_wMax, m_wMax] after the epoch.
662 for (
unsigned iw = 0; iw < ann->total_connections; ++iw) {
663 if (ann->weights[iw] > m_wMax)
664 ann->weights[iw] = m_wMax;
665 else if (ann->weights[iw] < -m_wMax)
666 ann->weights[iw] = -m_wMax;
// Two alternative calls for the validation error (parallel / serial).
671 double valid_mse = parallel_fann::test_data_parallel(ann, valid_data, m_nThreads);
673 double valid_mse = fann_test_data(ann, valid_data);
675 validLog[epoch - 1] = valid_mse;
// Remember the weights whenever the validation error reaches a new minimum.
677 if (valid_mse < bestValid) {
678 bestValid = valid_mse;
679 for (
unsigned iw = 0; iw < ann->total_connections; ++iw) {
680 bestWeights[iw] = ann->weights[iw];
// Early stopping: compare with the validation error logged at index
// epoch - m_checkInterval (the current epoch is stored at index epoch - 1).
685 if (epoch > m_checkInterval && valid_mse > validLog[epoch - m_checkInterval]) {
686 B2INFO(
"Training run " << irun <<
" stopped in epoch " << epoch);
687 B2INFO(
"Train error: " << mse <<
", valid error: " << valid_mse <<
688 ", best valid: " << bestValid);
// Progress output in epoch 1, every 10th epoch below 100, then every 100th.
693 if (epoch == 1 || (epoch < 100 && epoch % 10 == 0) || epoch % 100 == 0) {
694 B2INFO(
"Epoch " << epoch <<
": Train error = " << mse <<
695 ", valid error = " << valid_mse <<
", best valid = " << bestValid);
// breakEpoch still being 0 means the run was not stopped early.
698 if (breakEpoch == 0) {
699 B2INFO(
"Training run " << irun <<
" finished in epoch " << m_maxEpochs);
700 breakEpoch = m_maxEpochs;
// Log the errors at the optimum of this run.
702 trainOptLog.push_back(trainLog[bestEpoch - 1]);
703 validOptLog.push_back(validLog[bestEpoch - 1]);
// Evaluate the run on the test samples: install the best weights in the MLP
// (keeping the previous ones) and sum the squared deviations per output.
705 vector<float> oldWeights = m_NeuroTrigger[isector].getWeights();
706 m_NeuroTrigger[isector].weights = bestWeights;
707 vector<double> sumSqr;
708 sumSqr.assign(nNodes[nLayers - 1], 0.);
// The test samples are those after the training and validation samples.
709 for (
unsigned i = nTrain + m_nValid; i < m_trainSets[isector].nSamples(); ++i) {
710 vector<float> output = m_NeuroTrigger.runMLP(isector, m_trainSets[isector].getInput(i));
711 vector<float> target = m_trainSets[isector].getTarget(i);
712 for (
unsigned iout = 0; iout < output.size(); ++iout) {
// The stored target is unscaled before it is compared with the MLP output.
713 float diff = output[iout] - m_NeuroTrigger[isector].unscaleTarget(target)[iout];
714 sumSqr[iout] += diff * diff;
// Report the RMS per enabled output and accumulate the total.
717 double sumSqrTotal = 0;
718 if (m_parameters.targetZ) {
719 sumSqrTotal += sumSqr[m_NeuroTrigger[isector].zIndex()];
720 B2INFO(
"RMS z: " << sqrt(sumSqr[m_NeuroTrigger[isector].zIndex()] / m_nTest) <<
"cm");
722 if (m_parameters.targetTheta) {
724 sumSqrTotal += sumSqr[m_NeuroTrigger[isector].thetaIndex()];
725 B2INFO(
"RMS theta: " << sqrt(sumSqr[m_NeuroTrigger[isector].thetaIndex()] / m_nTest) <<
"deg");
// Combined RMS, averaged over the test samples and the number of outputs.
727 double RMS = sqrt(sumSqrTotal / m_nTest / sumSqr.size());
728 B2INFO(
"RMS on test samples: " << RMS <<
" (best: " << bestRMS <<
")");
// NOTE(review): the condition around the next lines is not visible;
// presumably the logs are kept when this run is the best so far and the old
// weights are restored otherwise. Only the first breakEpoch entries are kept.
731 bestTrainLog.assign(trainLog.begin(), trainLog.begin() + breakEpoch);
732 bestValidLog.assign(validLog.begin(), validLog.begin() + breakEpoch);
734 m_NeuroTrigger[isector].weights = oldWeights;
// Write the error curves of the best run, one "train valid" pair per line.
740 ofstream logstream(m_logFilename +
"_BestRun_" + to_string(isector) +
".log");
741 for (
unsigned i = 0; i < bestTrainLog.size(); ++i) {
742 logstream << bestTrainLog[i] <<
" " << bestValidLog[i] << endl;
// Write the errors at the optimum of every run, one pair per line.
746 ofstream logstreamOpt(m_logFilename +
"_AllOptima_" + to_string(isector) +
".log");
747 for (
unsigned i = 0; i < trainOptLog.size(); ++i) {
748 logstreamOpt << trainOptLog[i] <<
" " << validOptLog[i] << endl;
750 logstreamOpt.close();
// Free the FANN data sets.
// NOTE(review): the release of `ann` is not visible in this excerpt -- confirm.
753 fann_destroy_train(train_data);
754 fann_destroy_train(valid_data);
// NOTE(review): fragment of the routine that writes the training data to a
// ROOT file. The function header, the closing braces and some statements
// (e.g. closing the file, clearing the other histogram vectors) are not
// visible in this excerpt.
762 B2INFO(
"Saving traindata to file " << filename <<
", array " << arrayname);
// Open the file in UPDATE mode.
763 TFile datafile(filename.c_str(),
"UPDATE");
// Collect pointers to all training sets in one TObjArray.
// NOTE(review): the array is allocated with new; its release is not visible
// in this excerpt -- confirm.
764 TObjArray* trainSets =
new TObjArray(m_trainSets.size());
765 for (
unsigned isector = 0; isector < m_trainSets.size(); ++isector) {
766 trainSets->Add(&m_trainSets[isector]);
// Write the debug histograms of the sector under their own names, replacing
// existing objects (presumably only if m_saveDebug is set; the condition is
// not visible here).
768 phiHistsMC[isector]->Write(phiHistsMC[isector]->GetName(), TObject::kOverwrite);
769 ptHistsMC[isector]->Write(ptHistsMC[isector]->GetName(), TObject::kOverwrite);
770 thetaHistsMC[isector]->Write(thetaHistsMC[isector]->GetName(), TObject::kOverwrite);
771 zHistsMC[isector]->Write(zHistsMC[isector]->GetName(), TObject::kOverwrite);
772 phiHists2D[isector]->Write(phiHists2D[isector]->GetName(), TObject::kOverwrite);
773 ptHists2D[isector]->Write(ptHists2D[isector]->GetName(), TObject::kOverwrite);
// Write the whole array under a single key, replacing an existing one.
776 trainSets->Write(arrayname.c_str(), TObject::kSingleKey | TObject::kOverwrite);
// Delete the histograms of all sectors. The loop runs over the size of
// phiHistsMC and uses the same index for the other histogram vectors.
780 for (
unsigned isector = 0; isector < phiHistsMC.size(); ++ isector) {
781 delete phiHistsMC[isector];
782 delete ptHistsMC[isector];
783 delete thetaHistsMC[isector];
784 delete zHistsMC[isector];
785 delete phiHists2D[isector];
786 delete ptHists2D[isector];
790 thetaHistsMC.clear();
// NOTE(review): fragment of the routine that reads the training data from a
// ROOT file. The function header, the return statements, the declaration of
// `samples`, the check on `trainSets` and the closing braces are not visible
// in this excerpt.
// Open the file read-only; warn if that fails.
799 TFile datafile(filename.c_str(),
"READ");
800 if (!datafile.IsOpen()) {
801 B2WARNING(
"Could not open file " << filename);
// Fetch the array with the training sets by its key name.
804 TObjArray* trainSets = (TObjArray*)datafile.Get(arrayname.c_str());
807 B2WARNING(
"File " << filename <<
" does not contain key " << arrayname);
// Copy every entry into m_trainSets when `samples` is non-null; otherwise
// warn about the wrong type and ignore the entry.
811 for (
int isector = 0; isector < trainSets->GetEntriesFast(); ++isector) {
813 if (samples) m_trainSets.push_back(*samples);
814 else B2WARNING(
"Wrong type " << trainSets->At(isector)->ClassName() <<
", ignoring this entry.");
819 B2DEBUG(100,
"loaded " << m_trainSets.size() <<
" training sets");