11 #include <svd/modules/svdReconstruction/SVDNNClusterizerModule.h>
13 #include <framework/datastore/StoreArray.h>
14 #include <framework/logging/Logger.h>
16 #include <vxd/geometry/GeoCache.h>
17 #include <svd/geometry/SensorInfo.h>
19 #include <mdst/dataobjects/MCParticle.h>
20 #include <svd/dataobjects/SVDTrueHit.h>
21 #include <svd/dataobjects/SVDShaperDigit.h>
22 #include <svd/dataobjects/SVDRecoDigit.h>
23 #include <svd/dataobjects/SVDCluster.h>
24 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
26 #include <svd/reconstruction/NNWaveFitTool.h>
// Module constructor body (its signature is not visible in this excerpt):
// sets the module description, marks the module as certified for parallel
// processing, and registers the steering parameters.
48 B2DEBUG(200,
"SVDNNClusterizerModule ctor");
50 setDescription(
"Clusterize SVDRecoDigits and reconstruct hits");
51 setPropertyFlags(c_ParallelProcessingCertified);
// Data-store collection names; each defaults to the empty string.
54 addParam(
"Digits", m_storeRecoDigitsName,
55 "RecoDigits collection name",
string(
""));
56 addParam(
"Clusters", m_storeClustersName,
57 "Cluster collection name",
string(
""));
58 addParam(
"TrueHits", m_storeTrueHitsName,
59 "TrueHit collection name",
string(
""));
60 addParam(
"MCParticles", m_storeMCParticlesName,
61 "MCParticles collection name",
string(
""));
// Name of the NN time-fitter data, and a switch selecting calibrated
// (instead of default) peak widths and positions.
64 addParam(
"TimeFitterName", m_timeFitterName,
65 "Name of time fitter data file",
string(
"SVDTimeNet_6samples"));
66 addParam(
"CalibratePeak", m_calibratePeak,
"Use calibration (vs. default) for peak widths and positions",
bool(
false));
// Clustering cuts. Each default is the member's current value, which is
// initialized outside this excerpt.
70 addParam(
"NoiseSN", m_cutAdjacent,
71 "SN for digits to be considered for clustering", m_cutAdjacent);
72 addParam(
"SeedSN", m_cutSeed,
73 "SN for digits to be considered as seed", m_cutSeed);
74 addParam(
"ClusterSN", m_cutCluster,
75 "Minimum SN for clusters", m_cutCluster);
76 addParam(
"HeadTailSize", m_sizeHeadTail,
77 "Cluster size at which to switch to Analog head tail algorithm", m_sizeHeadTail);
// initialize(): registers the output cluster array and declares its inputs.
// RecoDigits are required; TrueHits and MCParticles are optional.
// It then builds the relation arrays between these collections and records
// the resolved collection/relation names in members, which event() uses to
// rebuild its RelationArrays. Finally it logs the configuration and hands
// the time-fitter network data to m_fitter.
81 void SVDNNClusterizerModule::initialize()
// NOTE(review): the declarations of storeClusters, storeRecoDigits,
// storeTrueHits and storeMCParticles are not visible in this excerpt.
89 storeClusters.registerInDataStore();
90 storeRecoDigits.isRequired();
91 storeTrueHits.isOptional();
92 storeMCParticles.isOptional();
// Relations between clusters, reco digits and truth objects.
// NOTE(review): whether/how these are registered happens in lines not
// visible in this excerpt.
94 RelationArray relClusterRecoDigits(storeClusters, storeRecoDigits);
95 RelationArray relClusterTrueHits(storeClusters, storeTrueHits);
96 RelationArray relClusterMCParticles(storeClusters, storeMCParticles);
97 RelationArray relRecoDigitTrueHits(storeRecoDigits, storeTrueHits);
98 RelationArray relRecoDigitMCParticles(storeRecoDigits, storeMCParticles);
// Cache the names actually used by the DataStore (empty parameters resolve
// to the default names), so event() constructs its relations with them.
108 m_storeClustersName = storeClusters.getName();
109 m_storeRecoDigitsName = storeRecoDigits.getName();
110 m_storeTrueHitsName = storeTrueHits.getName();
111 m_storeMCParticlesName = storeMCParticles.getName();
113 m_relClusterRecoDigitName = relClusterRecoDigits.
getName();
114 m_relClusterTrueHitName = relClusterTrueHits.
getName();
115 m_relClusterMCParticleName = relClusterMCParticles.
getName();
116 m_relRecoDigitTrueHitName = relRecoDigitTrueHits.
getName();
117 m_relRecoDigitMCParticleName = relRecoDigitMCParticles.
getName();
// Log the module configuration.
// NOTE(review): the printed section numbers jump from "2." to "4.".
119 B2INFO(
" 1. COLLECTIONS:");
120 B2INFO(
" --> MCParticles: " << m_storeMCParticlesName);
121 B2INFO(
" --> Digits: " << m_storeRecoDigitsName);
122 B2INFO(
" --> Clusters: " << m_storeClustersName);
123 B2INFO(
" --> TrueHits: " << m_storeTrueHitsName);
124 B2INFO(
" --> DigitMCRel: " << m_relRecoDigitMCParticleName);
125 B2INFO(
" --> ClusterMCRel: " << m_relClusterMCParticleName);
126 B2INFO(
" --> ClusterDigitRel: " << m_relClusterRecoDigitName);
127 B2INFO(
" --> DigitTrueRel: " << m_relRecoDigitTrueHitName);
128 B2INFO(
" --> ClusterTrueRel: " << m_relClusterTrueHitName);
129 B2INFO(
" 2. CALIBRATION DATA:");
130 B2INFO(
" --> Time NN: " << m_timeFitterName);
131 B2INFO(
" 4. CLUSTERING:");
132 B2INFO(
" --> Neighbour cut: " << m_cutAdjacent);
133 B2INFO(
" --> Seed cut: " << m_cutSeed);
134 B2INFO(
" --> Cluster charge cut: " << m_cutCluster);
135 B2INFO(
" --> HT for clusters >: " << m_sizeHeadTail);
// Pass the network data held in dbXml->m_data to the fitter. dbXml is
// obtained in lines not visible here; presumably it is the database
// weight-file payload — TODO confirm. "setNetwrok" is the fitter's own
// method name.
141 m_fitter.setNetwrok(dbXml->m_data);
// createRelationLookup(): fills `lookup` so that entry i points to the
// element of `relation` whose "from" index is i (event() calls it with the
// RecoDigit->MCParticle and RecoDigit->TrueHit relations and nDigits).
// If `relation` is not valid, returns without resizing or filling `lookup`.
// NOTE(review): resize() keeps existing entries, so pointers left over from
// a previous event are discarded only if `lookup` is cleared earlier in this
// function (those lines are not visible here) — confirm.
144 void SVDNNClusterizerModule::createRelationLookup(
const RelationArray& relation,
149 if (!relation)
return;
// One slot per digit (`digits` entries).
151 lookup.resize(digits);
152 for (
const auto& element : relation) {
153 lookup[element.getFromIndex()] = &element;
// fillRelationMap() (the start of its signature is not visible here): if
// `lookup` holds an element for digit `index`, adds each of that element's
// non-negative weights to relation[toIndex]. Called once per digit of a
// cluster, so weights from several digits accumulate per target index.
158 std::map<unsigned int, float>& relation,
unsigned int index)
161 if (!lookup.empty() && lookup[index]) {
// NOTE(review): the definition of `element` (presumably *lookup[index]) is
// not visible in this excerpt.
163 const unsigned int size = element.getSize();
165 for (
unsigned int i = 0; i < size; ++i) {
// Negative weights are skipped, not added to the map.
168 if (element.getWeight(i) < 0)
continue;
169 relation[element.getToIndex(i)] += element.getWeight(i);
// event(): clusters this event's SVDRecoDigits.
// 1. Split the digits into per-sensor index ranges, then into groups of
//    consecutive strips.
// 2. For each group, fit a common cluster time by multiplying the per-strip
//    time PDFs.
// 3. Refit the strip amplitudes at that time and compute the cluster charge,
//    seed charge, S/N, chi2 and position.
// 4. Store accepted clusters with relations to their digits, MCParticles and
//    TrueHits.
174 void SVDNNClusterizerModule::event()
// Nothing to do when the RecoDigits array is missing or empty.
178 if (!storeRecoDigits || !storeRecoDigits.
getEntries())
return;
180 size_t nDigits = storeRecoDigits.
getEntries();
181 B2DEBUG(90,
"Initial size of RecoDigits array: " << nDigits);
// Input relations from reco digits to truth objects.
186 RelationArray relRecoDigitMCParticle(storeRecoDigits, storeMCParticles, m_relRecoDigitMCParticleName);
187 RelationArray relRecoDigitTrueHit(storeRecoDigits, storeTrueHits, m_relRecoDigitTrueHitName);
// Reset the output clusters and any existing output relations.
190 storeClusters.
clear();
192 RelationArray relClusterMCParticle(storeClusters, storeMCParticles,
193 m_relClusterMCParticleName);
194 if (relClusterMCParticle) relClusterMCParticle.
clear();
196 RelationArray relClusterRecoDigit(storeClusters, storeRecoDigits,
197 m_relClusterRecoDigitName);
198 if (relClusterRecoDigit) relClusterRecoDigit.
clear();
200 RelationArray relClusterTrueHit(storeClusters, storeTrueHits,
201 m_relClusterTrueHitName);
202 if (relClusterTrueHit) relClusterTrueHit.
clear();
// Per-digit lookup tables consumed by fillRelationMap() below.
205 createRelationLookup(relRecoDigitMCParticle, m_mcRelation, nDigits);
206 createRelationLookup(relRecoDigitTrueHit, m_trueRelation, nDigits);
// Split the digit array into half-open [first, last) index ranges. A new
// range starts whenever the sensor ID changes, which assumes the digits are
// stored grouped by sensor. Indices are kept as size_t: an unsigned short
// pair would silently wrap for events with more than 65535 digits.
// NOTE(review): if the first digit's sensor differs from VxdID(0), an empty
// range (0, 0) is pushed first. Its strip loop does nothing, but confirm
// this is intended.
212 vector<pair<size_t, size_t> > sensorDigits;
213 VxdID lastSensorID(0);
214 size_t firstSensorDigit = 0;
215 for (
size_t iDigit = 0; iDigit < nDigits; ++iDigit) {
// NOTE(review): sensorID for digit iDigit is obtained in lines not visible here.
219 if (sensorID != lastSensorID) {
220 sensorDigits.push_back(make_pair(firstSensorDigit, iDigit));
221 firstSensorDigit = iDigit;
222 lastSensorID = sensorID;
// Close the final range at the end of the array.
226 sensorDigits.push_back(make_pair(firstSensorDigit, nDigits));
// Cluster each sensor range independently.
229 for (
auto id_indices : sensorDigits) {
231 unsigned int firstDigit = id_indices.first;
232 unsigned int lastDigit = id_indices.second;
// The strip side (U/V) is taken from the range's first digit.
234 const SVDRecoDigit& sampleRecoDigit = *storeRecoDigits[firstDigit];
236 bool isU = sampleRecoDigit.
isUStrip();
// Group runs of consecutive strip numbers into half-open [first, last)
// digit ranges. lastStrip starts at -2, so the first digit never counts as
// consecutive.
251 vector<pair<size_t, size_t> > stripGroups;
252 size_t firstClusterDigit = firstDigit;
253 size_t lastClusterDigit = firstDigit;
254 short lastStrip = -2;
256 B2DEBUG(300,
"Clustering digits " << firstDigit <<
" to " << lastDigit);
257 for (
size_t iDigit = firstDigit; iDigit < lastDigit; ++iDigit) {
259 const SVDRecoDigit& recoDigit = *storeRecoDigits[iDigit];
260 unsigned short currentStrip = recoDigit.
getCellID();
261 B2DEBUG(300,
"Digit " << iDigit <<
", strip: " << currentStrip <<
", lastStrip: " << lastStrip);
262 B2DEBUG(300,
"First CD: " << firstClusterDigit <<
" Last CD: " << lastClusterDigit);
265 bool consecutive = ((currentStrip - lastStrip) == 1);
266 lastStrip = currentStrip;
268 B2DEBUG(300, (consecutive ?
"consecutive" :
"gap"));
// A gap closes the current group if it is non-empty.
271 if (!consecutive && (firstClusterDigit < lastClusterDigit)) {
272 B2DEBUG(300,
"Saving (" << firstClusterDigit <<
", " << lastClusterDigit <<
")");
273 stripGroups.emplace_back(firstClusterDigit, lastClusterDigit);
278 lastClusterDigit = iDigit + 1;
280 firstClusterDigit = iDigit;
281 lastClusterDigit = iDigit + 1;
// Save the last open group, which ends at the end of the sensor range.
285 if (firstClusterDigit < lastClusterDigit) {
286 B2DEBUG(300,
"Saving (" << firstClusterDigit <<
", " << lastDigit <<
")");
287 stripGroups.emplace_back(firstClusterDigit, lastDigit);
291 os <<
"StripGroups: " << endl;
292 for (
auto item : stripGroups) {
293 os <<
"(" << item.first <<
", " << item.second <<
"), ";
295 B2DEBUG(300, os.str());
// Per-strip work buffers, reused for every cluster candidate.
304 vector<unsigned short> stripNumbers;
305 vector<float> stripPositions;
306 vector<float> stripNoises;
307 vector<float> stripGains;
308 vector<float> timeShifts;
309 vector<float> waveWidths;
310 vector<apvSamples> storedNormedSamples;
311 vector<SVDRecoDigit::OutputProbArray> storedPDFs;
// Each strip group is one cluster candidate.
314 for (
auto clusterBounds : stripGroups) {
316 unsigned short clusterSize = clusterBounds.second - clusterBounds.first;
317 assert(clusterSize > 0);
// NOTE(review): only stripNumbers and stripPositions are cleared here. The
// other per-strip buffers are indexed from 0 for every cluster below, so they
// must be cleared in lines not visible in this excerpt — confirm.
319 stripNumbers.clear();
320 stripPositions.clear();
// Collect per-strip data for this cluster.
326 for (
size_t iDigit = clusterBounds.first; iDigit < clusterBounds.second; ++iDigit) {
330 const SVDRecoDigit& recoDigit = *storeRecoDigits[iDigit];
332 unsigned short stripNo = recoDigit.
getCellID();
333 stripNumbers.push_back(stripNo);
// Strip noise in ADU from the noise calibration, stored converted to charge
// via the pulse-shape calibration.
335 double stripNoiseADU = m_noiseCal.getNoise(sensorID, isU, stripNo);
336 stripNoises.push_back(
337 m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripNoiseADU)
// Peak width and time shift use fixed defaults (270; 4.0 for U, 0.0 for V)
// unless m_calibratePeak is set, in which case they come from the pulse-shape
// calibration. NOTE(review): the units/meaning of these constants and of the
// 1.988 factor are not documented here.
342 double peakWidth = 270;
343 double timeShift = isU ? 4.0 : 0.0;
344 if (m_calibratePeak) {
345 peakWidth = 1.988 * m_pulseShapeCal.getWidth(sensorID, isU, stripNo);
346 timeShift = m_pulseShapeCal.getPeakTime(sensorID, isU, stripNo)
349 waveWidths.push_back(peakWidth);
350 timeShifts.push_back(timeShift);
351 stripPositions.push_back(
358 B2FATAL(
"Missing SVDRecoDigits->SVDShaperDigits relation. This should not happen.");
// Normalize the raw samples by the strip noise in ADU, using float division
// as the former divides<float> functor did. The lambda replaces std::bind2nd,
// which was deprecated in C++11 and removed in C++17.
360 transform(samples.begin(), samples.end(), normedSamples.begin(),
361 [stripNoiseADU](float sample) { return sample / static_cast<float>(stripNoiseADU); });
// Apply zeroSuppress() with the neighbour cut, then keep the normalized samples.
365 zeroSuppress(normedSamples, m_cutAdjacent);
366 storedNormedSamples.emplace_back(normedSamples);
// Combined cluster noise from the strip noises (logged as "RMS cluster noise").
372 float clusterNoise = sqrt(
374 * inner_product(stripNoises.begin(), stripNoises.end(), stripNoises.begin(), 0.0)
376 B2DEBUG(200,
"RMS cluster noise: " << clusterNoise);
// Cluster time: start from an all-ones PDF and multiply in each strip's time
// PDF taken from its RecoDigit.
380 shared_ptr<nnFitterBinData> pStrip;
383 fill(pCluster.begin(), pCluster.end(),
double(1.0));
385 for (
size_t iClusterStrip = 0; iClusterStrip < clusterSize; ++iClusterStrip) {
386 size_t iDigit = clusterBounds.first + iClusterStrip;
388 os1 <<
"Input to NNFitter: iDigit = " << iDigit << endl <<
"Samples: ";
389 copy(storedNormedSamples[iClusterStrip].begin(), storedNormedSamples[iClusterStrip].end(),
390 ostream_iterator<double>(os1,
" "));
392 os1 <<
"PDF from RecoDigit: " << endl;
393 copy(storedPDFs[iClusterStrip].begin(), storedPDFs[iClusterStrip].end(), ostream_iterator<double>(os1,
" "));
395 fitTool.
multiply(pCluster, storedPDFs[iClusterStrip]);
396 os1 <<
"Accumulated: " << endl;
397 copy(pCluster.begin(), pCluster.end(), ostream_iterator<double>(os1,
" "));
398 B2DEBUG(200, os1.str());
// Cluster time and its uncertainty from the combined PDF.
401 double clusterTime, clusterTimeErr;
402 tie(clusterTime, clusterTimeErr) = fitTool.
getTimeShift(pCluster);
403 B2DEBUG(200,
"Time: " << clusterTime <<
" +/- " << clusterTimeErr);
// Refit each strip's S/N amplitude at the common cluster time, then scale by
// the strip noise to get the amplitude in charge units.
407 vector<double> stripAmplitudes(stripNoises.size());
408 vector<double> stripAmplitudeErrors(stripNoises.size());
409 double clusterChi2 = 0.0;
410 for (
size_t iClusterStrip = 0; iClusterStrip < clusterSize; ++iClusterStrip) {
411 size_t iDigit = clusterBounds.first + iClusterStrip;
412 double snAmp, snAmpError, chi2;
413 tie(snAmp, snAmpError, chi2) =
414 fitTool.
getAmplitudeChi2(storedNormedSamples[iClusterStrip], clusterTime, waveWidths[iClusterStrip]);
416 stripAmplitudes[iClusterStrip] = stripNoises[iClusterStrip] * snAmp;
417 stripAmplitudeErrors[iClusterStrip] = stripNoises[iClusterStrip] * snAmpError;
419 B2DEBUG(200,
"Digit " << iDigit <<
" Noise: " << stripNoises[iClusterStrip]
420 <<
" Amplitude: " << stripAmplitudes[iClusterStrip]
421 <<
" +/- " << stripAmplitudeErrors[iClusterStrip]
// Cluster charge is the sum of the strip amplitudes, and its error adds the
// strip errors in quadrature. S/N falls back to the raw charge when the
// error is zero.
426 float clusterCharge = accumulate(stripAmplitudes.begin(), stripAmplitudes.end(), 0.0);
427 float clusterChargeError = sqrt(
428 inner_product(stripAmplitudeErrors.begin(), stripAmplitudeErrors.end(),
429 stripAmplitudeErrors.begin(), 0.0)
431 float clusterSN = (clusterChargeError > 0) ? clusterCharge / clusterChargeError : clusterCharge;
// NOTE(review): the accumulation of per-strip chi2 into clusterChi2 is not
// visible in this excerpt. Here it is averaged over the cluster size.
433 clusterChi2 /= clusterSize;
// The seed is the strip with the largest amplitude.
435 size_t seedIndex = distance(stripAmplitudes.begin(), max_element(
436 stripAmplitudes.begin(), stripAmplitudes.end()));
437 float clusterSeedCharge = stripAmplitudes[seedIndex];
438 B2DEBUG(200,
"Cluster parameters:");
439 B2DEBUG(200,
"Charge: " << clusterCharge <<
" +/- " << clusterChargeError);
440 B2DEBUG(200,
"Seed: " << clusterSeedCharge <<
" +/- " << stripAmplitudeErrors[seedIndex]);
441 B2DEBUG(200,
"S/N: " << clusterSN);
442 B2DEBUG(200,
"chi2: " << clusterChi2);
445 float clusterPosition, clusterPositionError;
// Cluster-charge and seed-charge cuts, both expressed in units of the
// cluster noise. The statement this condition guards is not visible here.
450 if ((clusterCharge < clusterNoise * m_cutCluster) || (clusterSeedCharge < clusterNoise * m_cutSeed))
// Clusters smaller than m_sizeHeadTail: the position is the charge-weighted
// centre of gravity, and the error uses a "phantom" charge at the
// neighbour cut.
453 if (clusterSize < m_sizeHeadTail) {
454 clusterPosition = 1.0 / clusterCharge * inner_product(
455 stripAmplitudes.begin(), stripAmplitudes.end(), stripPositions.begin(), 0.0
458 float phantomCharge = m_cutAdjacent * clusterNoise;
459 if (clusterSize == 1) {
460 clusterPositionError = pitch * phantomCharge / (clusterCharge + phantomCharge);
462 clusterPositionError = pitch * phantomCharge / clusterCharge;
// Head-tail: the position comes from the two edge strips, each capped at the
// mean charge of the central strips. NOTE(review): this divides by
// (clusterSize - 2), so the branch needs m_sizeHeadTail >= 3 to avoid a
// division by zero.
465 float leftStripCharge = stripAmplitudes.front();
466 float leftPos = stripPositions.front();
467 float rightStripCharge = stripAmplitudes.back();
468 float rightPos = stripPositions.back();
469 float centreCharge = (clusterCharge - leftStripCharge - rightStripCharge) / (clusterSize - 2);
470 leftStripCharge = (leftStripCharge < centreCharge) ? leftStripCharge : centreCharge;
471 rightStripCharge = (rightStripCharge < centreCharge) ? rightStripCharge : centreCharge;
472 clusterPosition = 0.5 * (leftPos + rightPos)
473 + 0.5 * (rightStripCharge - leftStripCharge) / centreCharge * pitch;
474 float sn = centreCharge / m_cutAdjacent / clusterNoise;
476 float landauHead = leftStripCharge / centreCharge;
477 double landauTail = rightStripCharge / centreCharge;
478 clusterPositionError = 0.5 * pitch * sqrt(1.0 / sn / sn
479 + 0.5 * landauHead * landauHead + 0.5 * landauTail * landauTail);
// Collect the cluster's truth relations, plus its digit weights (the strip
// amplitudes).
485 map<unsigned int, float> mc_relations;
486 map<unsigned int, float> truehit_relations;
487 vector<pair<unsigned int, float> > digit_weights;
488 digit_weights.reserve(clusterSize);
490 for (
size_t iDigit = clusterBounds.first; iDigit < clusterBounds.second; ++iDigit) {
492 fillRelationMap(m_mcRelation, mc_relations, iDigit);
493 fillRelationMap(m_trueRelation, truehit_relations, iDigit);
495 digit_weights.emplace_back(iDigit, stripAmplitudes[iDigit - clusterBounds.first]);
500 VxdID clusterSensorID = sensorID;
503 SVDCluster(sensorID, isU, clusterPosition, clusterPositionError, clusterTime,
504 clusterTimeErr, clusterCharge, clusterSeedCharge, clusterSize, clusterSN, clusterChi2)
// Add the cluster's relations to truth objects (only when present) and to
// its digits.
508 if (!mc_relations.empty()) {
509 relClusterMCParticle.
add(clsIndex, mc_relations.begin(), mc_relations.end());
511 if (!truehit_relations.empty()) {
512 relClusterTrueHit.
add(clsIndex, truehit_relations.begin(), truehit_relations.end());
514 relClusterRecoDigit.
add(clsIndex, digit_weights.begin(), digit_weights.end());
519 B2DEBUG(100,
"Number of clusters: " << storeClusters.
getEntries());