11 #include <svd/modules/svdReconstruction/SVDNNShapeReconstructorModule.h>
13 #include <framework/datastore/DataStore.h>
14 #include <framework/datastore/StoreArray.h>
15 #include <framework/logging/Logger.h>
17 #include <mdst/dataobjects/MCParticle.h>
18 #include <svd/dataobjects/SVDTrueHit.h>
19 #include <svd/dataobjects/SVDShaperDigit.h>
20 #include <svd/dataobjects/SVDRecoDigit.h>
21 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
23 #include <svd/reconstruction/NNWaveFitTool.h>
// Constructor body: sets the module description and property flags and
// registers the module parameters. The constructor signature and the end of
// the last addParam() call are not present in this listing.
43 B2DEBUG(200,
"Now in SVDNNShapeReconstructorModule ctor");
45 setDescription(
"Reconstruct signals on SVD strips.");
// Mark the module with c_ParallelProcessingCertified.
46 setPropertyFlags(c_ParallelProcessingCertified);
// Collection names; the default for each is the empty string
// (initialize() overwrites these members with the names actually resolved).
49 addParam(
"Digits", m_storeShaperDigitsName,
50 "Shaperdigits collection name",
string(
""));
51 addParam(
"RecoDigits", m_storeRecoDigitsName,
52 "RecoDigits collection name",
string(
""));
53 addParam(
"TrueHits", m_storeTrueHitsName,
54 "TrueHits collection name",
string(
""));
55 addParam(
"MCParticles", m_storeMCParticlesName,
56 "MCParticles collection name",
string(
""));
// Default is the current value of m_writeRecoDigits (initialized outside this listing).
57 addParam(
"WriteRecoDigits", m_writeRecoDigits,
58 "Write RecoDigits to output?", m_writeRecoDigits);
// initialize() switches this to "SVDEventInfoSim" when no object is
// available under the configured name.
59 addParam(
"SVDEventInfo", m_svdEventInfoName,
60 "SVDEventInfo name",
string(
""));
// Name of the NN time-fitter network that initialize() loads.
62 addParam(
"TimeFitterName", m_timeFitterName,
63 "Name of time fitter data file",
string(
"SVDTimeNet_6samples"));
// When true, event() takes strip peak width and peak time from
// m_pulseShapeCal instead of fixed default values.
64 addParam(
"CalibratePeak", m_calibratePeak,
"Use calibration (vs. default) for peak widths and positions",
bool(
false));
// Cut applied in event() to the noise-normalized samples
// (pass3Samples / zeroSuppress).
66 addParam(
"ZeroSuppressionCut", m_cutAdjacent,
"Zero-suppression cut on digits",
// Module initialization: registers/requires the DataStore collections,
// caches the resolved collection and relation names for event(), logs the
// configuration and loads the NN time-fitter network.
// Several lines of this function are not present in this listing.
70 void SVDNNShapeReconstructorModule::initialize()
// Register the RecoDigit array: with DataStore::c_DontWriteOut when
// m_writeRecoDigits is false, otherwise with default flags. The line between
// the two calls is not shown here (presumably an `else` — confirm).
78 if (!m_writeRecoDigits)
79 storeRecoDigits.registerInDataStore(DataStore::c_DontWriteOut);
81 storeRecoDigits.registerInDataStore();
// Inputs: shaper digits and SVDEventInfo are marked required,
// TrueHits and MCParticles optional.
83 storeShaperDigits.isRequired();
84 storeTrueHits.isOptional();
85 storeMCParticles.isOptional();
// NOTE(review): this requires the SVDEventInfo under its default name before
// the fallback to "SVDEventInfoSim" below is tried; if only SVDEventInfoSim
// exists, this call presumably fails first — confirm against
// StoreAccessorBase::isRequired semantics.
86 m_storeSVDEvtInfo.isRequired();
// Fall back to "SVDEventInfoSim" when no object is available under the
// configured name, then require whichever name was chosen.
88 if (!m_storeSVDEvtInfo.isOptional(m_svdEventInfoName)) m_svdEventInfoName =
"SVDEventInfoSim";
89 m_storeSVDEvtInfo.isRequired(m_svdEventInfoName);
// Relation arrays between the collections (default names); their names are
// cached below and used by event() to re-create them.
91 RelationArray relRecoDigitShaperDigits(storeRecoDigits, storeShaperDigits);
92 RelationArray relRecoDigitTrueHits(storeRecoDigits, storeTrueHits);
93 RelationArray relRecoDigitMCParticles(storeRecoDigits, storeMCParticles);
94 RelationArray relShaperDigitTrueHits(storeShaperDigits, storeTrueHits);
95 RelationArray relShaperDigitMCParticles(storeShaperDigits, storeMCParticles);
// NOTE(review): the statement(s) controlled by this `if` are not present in
// this listing; confirm what is done when RecoDigits are not written out.
97 if (!m_writeRecoDigits)
// Cache the resolved collection and relation names for event().
108 m_storeRecoDigitsName = storeRecoDigits.getName();
109 m_storeShaperDigitsName = storeShaperDigits.getName();
110 m_storeTrueHitsName = storeTrueHits.getName();
111 m_storeMCParticlesName = storeMCParticles.getName();
113 m_relRecoDigitShaperDigitName = relRecoDigitShaperDigits.
getName();
114 m_relRecoDigitTrueHitName = relRecoDigitTrueHits.
getName();
115 m_relRecoDigitMCParticleName = relRecoDigitMCParticles.
getName();
116 m_relShaperDigitTrueHitName = relShaperDigitTrueHits.
getName();
117 m_relShaperDigitMCParticleName = relShaperDigitMCParticles.
getName();
// Log the configuration.
119 B2INFO(
" 1. COLLECTIONS:");
120 B2INFO(
" --> MCParticles: " << m_storeMCParticlesName);
121 B2INFO(
" --> Digits: " << m_storeShaperDigitsName);
122 B2INFO(
" --> RecoDigits: " << m_storeRecoDigitsName);
123 B2INFO(
" --> TrueHits: " << m_storeTrueHitsName);
124 B2INFO(
" --> DigitMCRel: " << m_relShaperDigitMCParticleName);
125 B2INFO(
" --> RecoDigitMCRel: " << m_relRecoDigitMCParticleName);
126 B2INFO(
" --> RecoDigitDigitRel: " << m_relRecoDigitShaperDigitName);
127 B2INFO(
" --> DigitTrueRel: " << m_relShaperDigitTrueHitName);
128 B2INFO(
" --> RecoDigitTrueRel: " << m_relRecoDigitTrueHitName);
129 B2INFO(
" --> Save RecoDigits? " << (m_writeRecoDigits ?
"Y" :
"N"));
130 B2INFO(
" 2. CALIBRATION:");
131 B2INFO(
" --> Time NN: " << m_timeFitterName);
// Load the time-fitter network from dbXml->m_data. dbXml is obtained in lines
// not shown (presumably the weightfile payload named by m_timeFitterName —
// TODO confirm). NOTE(review): `setNetwrok` must match the NNWaveFitter API
// spelling; do not rename it here alone.
137 m_fitter.setNetwrok(dbXml->m_data);
// Build a lookup table indexed by the relation's "from" index (the digit
// index for the relations built in event()), holding a pointer to the
// matching relation element. The remaining parameters (`lookup`, `digits`)
// are declared in lines not present in this listing.
140 void SVDNNShapeReconstructorModule::createRelationLookup(
const RelationArray& relation,
// Nothing to index when the relation array is not valid.
145 if (!relation)
return;
// NOTE(review): no clear() is visible before this resize; if the hidden lines
// do not clear `lookup`, entries from a previous call survive both this resize
// and the early return above — confirm.
147 lookup.resize(digits);
// Entries point into `relation`, so the lookup is only valid while that
// relation's data is alive. Assumes every from-index is < digits (no bounds check).
148 for (
const auto& element : relation) {
149 lookup[element.getFromIndex()] = &element;
// Accumulate into `relation` the weights of all relation targets of digit
// `index`, keyed by target index. Does nothing when the lookup is empty or has
// no entry for this digit. `element` is bound in a line not present in this
// listing (presumably to *lookup[index] — confirm).
153 void SVDNNShapeReconstructorModule::fillRelationMap(
const RelationLookup& lookup,
154 std::map<unsigned int, float>& relation,
unsigned int index)
157 if (!lookup.empty() && lookup[index]) {
159 const unsigned int size = element.getSize();
161 for (
unsigned int i = 0; i < size; ++i) {
// Contributions with a negative weight are not propagated.
164 if (element.getWeight(i) < 0)
continue;
// Weights for the same target are summed (operator[] starts new keys at 0).
165 relation[element.getToIndex(i)] += element.getWeight(i);
// Per-event reconstruction, setup part: checks the inputs, prepares the
// relation arrays, empties the RecoDigit output and indexes the digit
// relations. The StoreArray declarations are in lines not present in this listing.
170 void SVDNNShapeReconstructorModule::event()
// Nothing to do without shaper digits or without a valid SVDEventInfo.
175 if (!storeShaperDigits || !storeShaperDigits.
getEntries() || !m_storeSVDEvtInfo.isValid())
return;
// Event-level mode byte; its trigger bin is used per strip further below.
177 SVDModeByte modeByte = m_storeSVDEvtInfo->getModeByte();
179 size_t nDigits = storeShaperDigits.
getEntries();
180 B2DEBUG(90,
"Initial size of StoreDigits array: " << nDigits);
// Existing digit->MCParticle and digit->TrueHit relations (names cached in initialize()).
185 RelationArray relShaperDigitMCParticle(storeShaperDigits, storeMCParticles, m_relShaperDigitMCParticleName);
186 RelationArray relShaperDigitTrueHit(storeShaperDigits, storeTrueHits, m_relShaperDigitTrueHitName);
// Start from an empty RecoDigit array and empty RecoDigit relations.
189 storeRecoDigits.
clear();
191 RelationArray relRecoDigitMCParticle(storeRecoDigits, storeMCParticles,
192 m_relRecoDigitMCParticleName);
193 if (relRecoDigitMCParticle) relRecoDigitMCParticle.
clear();
195 RelationArray relRecoDigitShaperDigit(storeRecoDigits, storeShaperDigits,
196 m_relRecoDigitShaperDigitName);
197 if (relRecoDigitShaperDigit) relRecoDigitShaperDigit.
clear();
199 RelationArray relRecoDigitTrueHit(storeRecoDigits, storeTrueHits,
200 m_relRecoDigitTrueHitName);
201 if (relRecoDigitTrueHit) relRecoDigitTrueHit.
clear();
// Index the digit relations by digit for fillRelationMap().
204 createRelationLookup(relShaperDigitMCParticle, m_mcRelation, nDigits);
205 createRelationLookup(relShaperDigitTrueHit, m_trueRelation, nDigits);
// Split the digit array into [first, last) index ranges of consecutive digits
// sharing the same sensorID (sensorID is computed in lines not present in this
// listing). NOTE(review): assumes digits of one sensor are stored contiguously.
// Indices are stored as size_t, matching nDigits/iDigit; the former
// unsigned short pairs silently truncated any index above 65535.
211 vector<pair<size_t, size_t> > sensorDigits;
// lastSensorID starts at VxdID(0), so unless the first digit's sensorID equals
// VxdID(0) the first iteration records an empty leading range [0, 0).
212 VxdID lastSensorID(0);
213 size_t firstSensorDigit = 0;
214 for (
size_t iDigit = 0; iDigit < nDigits; ++iDigit) {
218 if (sensorID != lastSensorID) {
219 sensorDigits.push_back(make_pair(firstSensorDigit, iDigit));
220 firstSensorDigit = iDigit;
221 lastSensorID = sensorID;
// Close the range of the last sensor.
225 sensorDigits.push_back(make_pair(firstSensorDigit, nDigits));
// Reconstruct each sensor's digit range [firstDigit, lastDigit).
228 for (
auto id_indices : sensorDigits) {
230 unsigned int firstDigit = id_indices.first;
231 unsigned int lastDigit = id_indices.second;
// First digit of the range; its use is in lines not present in this listing
// (presumably to obtain sensorID / isU used below — TODO confirm).
233 const SVDShaperDigit& exampleDigit = *storeShaperDigits[firstDigit];
239 B2DEBUG(300,
"Reconstructing digits " << firstDigit <<
" to " << lastDigit);
// Per-strip reconstruction (shaperDigit is bound in lines not shown).
240 for (
size_t iDigit = firstDigit; iDigit < lastDigit; ++iDigit) {
243 unsigned short stripNo = shaperDigit.
getCellID();
244 bool validDigit =
true;
// NOTE(review): magic constant 4 * 1.96516 (presumably a clock period in ns —
// units and meaning not shown here). apvPhase depends only on the event's
// modeByte, so it is identical for every strip and could be hoisted out of
// the strip loop.
245 const double triggerBinSep = 4 * 1.96516;
246 double apvPhase = triggerBinSep * (0.5 +
static_cast<int>(modeByte.
getTriggerBin()));
// Strip noise in ADC units from the noise calibration; used below to
// normalize the samples.
249 float stripNoiseADU = m_noiseCal.getNoise(sensorID, isU, stripNo);
// Default peak width and time offset (offset depends on isU); replaced by
// calibrated values when m_calibratePeak is set.
// NOTE(review): the constants 270, 2.5, -2.2, 1.988 and 0.25 are not explained here.
253 double stripSignalWidth = 270;
254 double stripT0 = isU ? 2.5 : -2.2;
255 if (m_calibratePeak) {
256 stripSignalWidth = 1.988 * m_pulseShapeCal.getWidth(sensorID, isU, stripNo);
257 stripT0 = m_pulseShapeCal.getPeakTime(sensorID, isU, stripNo)
258 - 0.25 * stripSignalWidth;
261 B2DEBUG(300,
"Strip parameters: stripNoiseADU: " << stripNoiseADU <<
262 " Width: " << stripSignalWidth <<
269 transform(samples.begin(), samples.end(), normedSamples.begin(),
270 bind2nd(divides<float>(), stripNoiseADU));
272 validDigit = validDigit &&
pass3Samples(normedSamples, m_cutAdjacent);
276 zeroSuppress(normedSamples, m_cutAdjacent);
// Fit the normalized samples with the NN fitter, logging the fitter input and
// output through `os` (declared in lines not present in this listing).
283 os <<
"Input to NNFitter: iDigit = " << iDigit << endl <<
"Samples: ";
284 copy(normedSamples.begin(), normedSamples.end(), ostream_iterator<double>(os,
" "));
// The NN fit itself: pStrip holds the fitter output bins for this strip; it
// is used below for the time estimate and passed on at RecoDigit creation.
286 std::shared_ptr<nnFitterBinData> pStrip = m_fitter.getFit(normedSamples, stripSignalWidth);
287 os <<
"Output from NNWaveFitter: " << endl;
288 copy(pStrip->begin(), pStrip->end(), ostream_iterator<double>(os,
" "));
292 B2DEBUG(200, os.str());
// Time and its error from the fit output (fitTool is declared in lines not shown).
294 double stripTime, stripTimeError;
295 tie(stripTime, stripTimeError) = fitTool.
getTimeShift(*pStrip);
// Amplitude, its error and chi2 (the right-hand side of this assignment is not
// present in this listing). The fit works on noise-normalized samples, so
// amplitude and error are multiplied back by stripNoiseADU and converted to
// charge with the pulse-shape calibration.
// NOTE(review): converting the error with getChargeFromADC is only correct if
// that conversion is linear without offset — confirm.
297 double stripAmplitude, stripAmplitudeError, stripChi2;
298 tie(stripAmplitude, stripAmplitudeError, stripChi2) =
301 stripAmplitude = m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripAmplitude * stripNoiseADU);
302 stripAmplitudeError = m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripAmplitudeError * stripNoiseADU);
303 B2DEBUG(200,
"RecoDigit " << iDigit <<
" Noise: " << m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripNoiseADU)
304 <<
" Time: " << stripTime <<
" +/- " << stripTimeError
305 <<
" Amplitude: " << stripAmplitude <<
" +/- " << stripAmplitudeError
306 <<
" Chi2: " << stripChi2
// Collect this digit's MCParticle and TrueHit relations (weights summed per
// target) and a single weight-1 link from the new RecoDigit to the digit itself.
310 map<unsigned int, float> mc_relations;
311 map<unsigned int, float> truehit_relations;
312 vector<pair<unsigned int, float> > digit_weights;
313 digit_weights.reserve(1);
316 fillRelationMap(m_mcRelation, mc_relations, iDigit);
317 fillRelationMap(m_trueRelation, truehit_relations, iDigit);
319 digit_weights.emplace_back(iDigit, 1.0);
// Index the new RecoDigit will get: the array size before it is appended.
// The append call is only partly present in this listing (one argument line).
322 int recoDigitIndex = storeRecoDigits.
getEntries();
325 stripAmplitudeError, stripTime, stripTimeError, *pStrip, stripChi2,
// Attach relations to the new RecoDigit; MCParticle/TrueHit relations only
// when the digit had any, the digit link always.
330 if (!mc_relations.empty()) {
331 relRecoDigitMCParticle.
add(recoDigitIndex, mc_relations.begin(), mc_relations.end());
333 if (!truehit_relations.empty()) {
334 relRecoDigitTrueHit.
add(recoDigitIndex, truehit_relations.begin(), truehit_relations.end());
336 relRecoDigitShaperDigit.
add(recoDigitIndex, digit_weights.begin(), digit_weights.end());
// Log the number of entries in the RecoDigit array.
341 B2DEBUG(100,
"Number of RecoDigits: " << storeRecoDigits.
getEntries());