Belle II Software development
SVDNNShapeReconstructorModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <svd/modules/svdReconstruction/SVDNNShapeReconstructorModule.h>
10
11#include <framework/datastore/DataStore.h>
12#include <framework/datastore/StoreArray.h>
13#include <framework/logging/Logger.h>
14
15#include <mdst/dataobjects/MCParticle.h>
16#include <svd/dataobjects/SVDTrueHit.h>
17#include <svd/dataobjects/SVDShaperDigit.h>
18#include <svd/dataobjects/SVDRecoDigit.h>
19#include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
20
21#include <svd/reconstruction/NNWaveFitTool.h>
22
23#include <algorithm>
24#include <functional>
25
26using namespace std;
27using namespace std::placeholders;
28using namespace Belle2;
29using namespace Belle2::SVD;
30
31//-----------------------------------------------------------------
32// Register the Module
33//-----------------------------------------------------------------
34REG_MODULE(SVDNNShapeReconstructor);
35
36//-----------------------------------------------------------------
37// Implementation
38//-----------------------------------------------------------------
39
41{
// Constructor: declares the module parameters (input/output collection
// names, SVDEventInfo name, NN time-fitter data name, peak-calibration switch
// and zero-suppression cut).
// NOTE(review): some source lines are missing from this listing (numbering
// gaps at 40, 45, 48 and 66), so parts of the parameter list are not shown.
42 B2DEBUG(200, "Now in SVDNNShapeReconstructorModule ctor");
43 //Set module properties
44 setDescription("Reconstruct signals on SVD strips.");
46
47 // 1. Collections.
49 "Shaperdigits collection name", string(""));
50 addParam("RecoDigits", m_storeRecoDigitsName,
51 "RecoDigits collection name", string(""));
52 addParam("TrueHits", m_storeTrueHitsName,
53 "TrueHits collection name", string(""));
54 addParam("MCParticles", m_storeMCParticlesName,
55 "MCParticles collection name", string(""));
// The default for WriteRecoDigits is whatever m_writeRecoDigits already holds.
56 addParam("WriteRecoDigits", m_writeRecoDigits,
57 "Write RecoDigits to output?", m_writeRecoDigits);
58 addParam("SVDEventInfo", m_svdEventInfoName,
59 "SVDEventInfo name", string(""));
60 // 2. Calibration and time fitter sources
61 addParam("TimeFitterName", m_timeFitterName,
62 "Name of time fitter data file", string("SVDTimeNet_6samples"));
// Peak calibration is off by default; event() then uses fixed width/T0 defaults.
// NOTE(review): the description string below misspells "calibration"
// ("calibrattion"); fixing it needs a code change, not a comment change.
63 addParam("CalibratePeak", m_calibratePeak, "Use calibrattion (vs. default) for peak widths and positions", bool(false));
64 // 3. Zero suppression
65 addParam("ZeroSuppressionCut", m_cutAdjacent, "Zero-suppression cut on digits",
67}
68
70{
// initialize():
//  - registers the output RecoDigits and their relations;
//  - declares required/optional inputs;
//  - caches the resolved DataStore names for event();
//  - logs the configuration;
//  - passes the NN time-fitter network (dbXml->m_data) to m_fitter.
// NOTE(review): several source lines are missing from this listing
// (numbering gaps). These include the StoreArray declarations, the `if`
// heads of the two if/else registrations and the declaration of dbXml
// (presumably a database object named after m_timeFitterName — confirm).
71 //Register collections
76
79 else
80 storeRecoDigits.registerInDataStore();
81
// Required inputs: ShaperDigits and SVDEventInfo.
// Optional inputs: TrueHits and MCParticles.
82 storeShaperDigits.isRequired();
83 storeTrueHits.isOptional();
84 storeMCParticles.isOptional();
85 m_storeSVDEvtInfo.isRequired();
86
// Fall back to "SVDEventInfoSim" when no object is available under the configured name.
// NOTE(review): isRequired() above is called without a name argument
// before this fallback is evaluated, so an absent default-named object may
// already be reported as missing — confirm the intended order.
87 if (!m_storeSVDEvtInfo.isOptional(m_svdEventInfoName)) m_svdEventInfoName = "SVDEventInfoSim";
89
// Relation arrays between the collections above (default relation names).
90 RelationArray relRecoDigitShaperDigits(storeRecoDigits, storeShaperDigits);
91 RelationArray relRecoDigitTrueHits(storeRecoDigits, storeTrueHits);
92 RelationArray relRecoDigitMCParticles(storeRecoDigits, storeMCParticles);
93 RelationArray relShaperDigitTrueHits(storeShaperDigits, storeTrueHits);
94 RelationArray relShaperDigitMCParticles(storeShaperDigits, storeMCParticles);
95
97 relRecoDigitShaperDigits.registerInDataStore(DataStore::c_DontWriteOut);
98 else
99 relRecoDigitShaperDigits.registerInDataStore();
100 //Relations to simulation objects only if the ancestor relations exist
101 if (relShaperDigitTrueHits.isOptional())
102 relRecoDigitTrueHits.registerInDataStore();
103 if (relShaperDigitMCParticles.isOptional())
104 relRecoDigitMCParticles.registerInDataStore();
105
106 //Store names to speed up creation later
107 m_storeRecoDigitsName = storeRecoDigits.getName();
108 m_storeShaperDigitsName = storeShaperDigits.getName();
109 m_storeTrueHitsName = storeTrueHits.getName();
110 m_storeMCParticlesName = storeMCParticles.getName();
111
112 m_relRecoDigitShaperDigitName = relRecoDigitShaperDigits.getName();
113 m_relRecoDigitTrueHitName = relRecoDigitTrueHits.getName();
114 m_relRecoDigitMCParticleName = relRecoDigitMCParticles.getName();
115 m_relShaperDigitTrueHitName = relShaperDigitTrueHits.getName();
116 m_relShaperDigitMCParticleName = relShaperDigitMCParticles.getName();
117
118 B2INFO(" 1. COLLECTIONS:");
119 B2INFO(" --> MCParticles: " << m_storeMCParticlesName);
120 B2INFO(" --> Digits: " << m_storeShaperDigitsName);
121 B2INFO(" --> RecoDigits: " << m_storeRecoDigitsName);
122 B2INFO(" --> TrueHits: " << m_storeTrueHitsName);
123 B2INFO(" --> DigitMCRel: " << m_relShaperDigitMCParticleName);
124 B2INFO(" --> RecoDigitMCRel: " << m_relRecoDigitMCParticleName);
125 B2INFO(" --> RecoDigitDigitRel: " << m_relRecoDigitShaperDigitName);
126 B2INFO(" --> DigitTrueRel: " << m_relShaperDigitTrueHitName);
127 B2INFO(" --> RecoDigitTrueRel: " << m_relRecoDigitTrueHitName);
128 B2INFO(" --> Save RecoDigits? " << (m_writeRecoDigits ? "Y" : "N"));
129 B2INFO(" 2. CALIBRATION:");
130 B2INFO(" --> Time NN: " << m_timeFitterName);
131
132 // Properly initialize the NN time fitter
133 // FIXME: Should be moved to beginRun
134 // FIXME: No support for 3/6 sample switching within a run/event
// "setNetwrok" (sic) is the fitter's API spelling; it takes the network as XML text.
136 m_fitter.setNetwrok(dbXml->m_data);
137}
138
140 RelationLookup& lookup, size_t digits)
141{
// Build a table that maps each "from" index of `relation` (a ShaperDigit
// index) to its RelationElement, so fillRelationMap() can find a digit's
// relations by index. The table is left empty if the relation does not exist.
// The stored pointers refer to elements inside `relation`'s storage and are
// only valid while that storage is.
142 lookup.clear();
143 //If we don't have a relation we don't build a lookuptable
144 if (!relation) return;
145 //Resize to number of digits and set all values
// After clear(), every slot is new and value-initialized, i.e. nullptr.
146 lookup.resize(digits);
147 for (const auto& element : relation) {
// NOTE(review): getFromIndex() is not checked against `digits`; an index
// >= digits would write out of bounds. If several elements share a
// from-index, the last one wins.
148 lookup[element.getFromIndex()] = &element;
149 }
150}
151
153 std::map<unsigned int, float>& relation, unsigned int index)
154{
// Add to `relation` (to-index -> summed weight) every relation entry that
// the lookup table, built by createRelationLookup(), holds for the
// ShaperDigit at position `index`. Weights for the same to-index are summed.
// Entries with negative weight are skipped. Nothing is added when the table
// is empty or has no element for `index`.
155 //If the lookup table is not empty and the element is set
// NOTE(review): `index` is not bounds-checked against lookup.size(); callers
// must pass a valid digit index.
156 if (!lookup.empty() && lookup[index]) {
157 const RelationElement& element = *lookup[index];
158 const unsigned int size = element.getSize();
159 //Add all Relations to the map
160 for (unsigned int i = 0; i < size; ++i) {
161 //negative weights are from ignored particles, we don't like them and
162 //thus ignore them :D
163 if (element.getWeight(i) < 0) continue;
164 relation[element.getToIndex(i)] += element.getWeight(i);
165 }
166 }
167}
168
170{
// Per-event reconstruction. For every SVDShaperDigit that passes the
// 3-sample cut:
//  - normalize its samples by the strip noise;
//  - run the NN time fitter;
//  - derive time, amplitude and chi2;
//  - store an SVDRecoDigit with relations to its ShaperDigit and, when those
//    relations exist, to MCParticles and SVDTrueHits.
// NOTE(review): several source lines are missing from this listing,
// including the declarations of storeShaperDigits, storeRecoDigits and fitTool.
171
173 // If no digits or no SVDEventInfo, nothing to do
174 if (!storeShaperDigits || !storeShaperDigits.getEntries() || !m_storeSVDEvtInfo.isValid()) return;
175
176 SVDModeByte modeByte = m_storeSVDEvtInfo->getModeByte();
177
178 size_t nDigits = storeShaperDigits.getEntries();
179 B2DEBUG(90, "Initial size of StoreDigits array: " << nDigits);
180
181 const StoreArray<MCParticle> storeMCParticles(m_storeMCParticlesName);
182 const StoreArray<SVDTrueHit> storeTrueHits(m_storeTrueHitsName);
183
184 RelationArray relShaperDigitMCParticle(storeShaperDigits, storeMCParticles, m_relShaperDigitMCParticleName);
185 RelationArray relShaperDigitTrueHit(storeShaperDigits, storeTrueHits, m_relShaperDigitTrueHitName);
186
188 storeRecoDigits.clear();
189
190 RelationArray relRecoDigitMCParticle(storeRecoDigits, storeMCParticles,
192 if (relRecoDigitMCParticle) relRecoDigitMCParticle.clear();
193
194 RelationArray relRecoDigitShaperDigit(storeRecoDigits, storeShaperDigits,
196 if (relRecoDigitShaperDigit) relRecoDigitShaperDigit.clear();
197
198 RelationArray relRecoDigitTrueHit(storeRecoDigits, storeTrueHits,
200 if (relRecoDigitTrueHit) relRecoDigitTrueHit.clear();
201
202 //Build lookup tables for relations
203 createRelationLookup(relShaperDigitMCParticle, m_mcRelation, nDigits);
204 createRelationLookup(relShaperDigitTrueHit, m_trueRelation, nDigits);
205
206 // Create fit tool object
208
// Split the digits into runs of consecutive entries on the same sensor and
// side. The segment number encodes the side: 1 = U, 0 = V. Each entry is a
// half-open [first, last) index range.
// If digits are not sorted by sensor/side, one side may yield several
// ranges; each range is still homogeneous.
// Because lastSensorID starts as VxdID(0), the first pushed range is
// normally the empty range [0, 0), which the loop below skips harmlessly.
// NOTE(review): the indices are size_t but are stored as unsigned short, so
// they are silently truncated when an event holds more than 65535 digits.
// Consider vector<pair<size_t, size_t> >.
209 // I. Group digits by sensor/side.
210 vector<pair<unsigned short, unsigned short> > sensorDigits;
211 VxdID lastSensorID(0);
212 size_t firstSensorDigit = 0;
213 for (size_t iDigit = 0; iDigit < nDigits; ++iDigit) {
214 const SVDShaperDigit& digit = *storeShaperDigits[iDigit];
215 VxdID sensorID = digit.getSensorID();
216 sensorID.setSegmentNumber(digit.isUStrip() ? 1 : 0);
217 if (sensorID != lastSensorID) { // we have a new sensor side
218 sensorDigits.push_back(make_pair(firstSensorDigit, iDigit));
219 firstSensorDigit = iDigit;
220 lastSensorID = sensorID;
221 }
222 }
223 // save last VxdID
224 sensorDigits.push_back(make_pair(firstSensorDigit, nDigits));
225
226 // ICYCLE OVER SENSORS
227 for (auto id_indices : sensorDigits) {
228 // Retrieve parameters from sensorDigits
229 unsigned int firstDigit = id_indices.first;
230 unsigned int lastDigit = id_indices.second;
231 // Get VXDID and side from the first digit
232 const SVDShaperDigit& exampleDigit = *storeShaperDigits[firstDigit];
233 VxdID sensorID = exampleDigit.getSensorID();
234 bool isU = exampleDigit.isUStrip();
235
236 // 2. Cycle through digits and form recodigits on the way.
237
238 B2DEBUG(300, "Reconstructing digits " << firstDigit << " to " << lastDigit);
239 for (size_t iDigit = firstDigit; iDigit < lastDigit; ++iDigit) {
240
241 const SVDShaperDigit& shaperDigit = *storeShaperDigits[iDigit];
242 unsigned short stripNo = shaperDigit.getCellID();
243 bool validDigit = true; // FIXME: We don't care about local run bad strips for now.
// NOTE(review): triggerBinSep and apvPhase do not depend on the digit and
// could be computed once per event, before the loops.
244 const double triggerBinSep = 4 * 1.96516; //in ns
245 double apvPhase = triggerBinSep * (0.5 + static_cast<int>(modeByte.getTriggerBin()));
246 // Get things from the database.
247 // Noise is good as it comes.
248 float stripNoiseADU = m_noiseCal.getNoise(sensorID, isU, stripNo);
249 // Some calibrations magic.
250 // FIXME: Only use calibration on real data. Until simulations correspond to
251 // default calibrtion, we cannot use it.
// Fixed defaults (width 270; T0 +2.5 for U, -2.2 for V) are used unless
// m_calibratePeak is set, in which case they come from m_pulseShapeCal.
252 double stripSignalWidth = 270;
253 double stripT0 = isU ? 2.5 : -2.2;
254 if (m_calibratePeak) {
255 stripSignalWidth = 1.988 * m_pulseShapeCal.getWidth(sensorID, isU, stripNo);
256 stripT0 = m_pulseShapeCal.getPeakTime(sensorID, isU, stripNo)
257 - 0.25 * stripSignalWidth;
258 }
259
260 B2DEBUG(300, "Strip parameters: stripNoiseADU: " << stripNoiseADU <<
261 " Width: " << stripSignalWidth <<
262 " T0: " << stripT0);
263
// NOTE(review): no masking check is done here; samples are always divided by
// stripNoiseADU, which gives inf/NaN if the noise value is zero.
264 // If the strip is not masked away, normalize samples (sample/stripNoise)
265 apvSamples normedSamples;
266 auto samples = shaperDigit.getSamples();
267 transform(samples.begin(), samples.end(), normedSamples.begin(),
268 bind(divides<float>(), _1, stripNoiseADU));
269 // FIXME: This won't work in 3 sample mode, we have no control over the number of non-zero samples.
270 validDigit = validDigit && pass3Samples(normedSamples, m_cutAdjacent);
271
272 if (validDigit) {
273 zeroSuppress(normedSamples, m_cutAdjacent);
274 } else // only now we give up on the diigit
275 continue;
276
277 // 3. Now we create and save the RecoDigit
278
// NOTE(review): this debug text is built for every digit even when level-200
// debug output is disabled. Consider building it only when it will be
// printed. The "Output" dump is captured before the time shift below.
279 ostringstream os;
280 os << "Input to NNFitter: iDigit = " << iDigit << endl << "Samples: ";
281 copy(normedSamples.begin(), normedSamples.end(), ostream_iterator<double>(os, " "));
282 os << endl;
283 std::shared_ptr<nnFitterBinData> pStrip = m_fitter.getFit(normedSamples, stripSignalWidth);
284 os << "Output from NNWaveFitter: " << endl;
285 copy(pStrip->begin(), pStrip->end(), ostream_iterator<double>(os, " "));
286 os << endl;
287 // Apply strip time shift to pdf
288 fitTool.shiftInTime(*pStrip, -apvPhase - stripT0);
289 B2DEBUG(200, os.str());
290 // Calculate time and its error, amplitude and its error, and chi2
291 double stripTime, stripTimeError;
292 tie(stripTime, stripTimeError) = fitTool.getTimeShift(*pStrip);
293 // Now we have the cluster time pdf, so we can calculate amplitudes.
294 double stripAmplitude, stripAmplitudeError, stripChi2;
295 tie(stripAmplitude, stripAmplitudeError, stripChi2) =
296 fitTool.getAmplitudeChi2(normedSamples, stripTime, stripSignalWidth);
297 //De-normalize amplitudes and convert to electrons.
298 stripAmplitude = m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripAmplitude * stripNoiseADU);
299 stripAmplitudeError = m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripAmplitudeError * stripNoiseADU);
300 B2DEBUG(200, "RecoDigit " << iDigit << " Noise: " << m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripNoiseADU)
301 << " Time: " << stripTime << " +/- " << stripTimeError
302 << " Amplitude: " << stripAmplitude << " +/- " << stripAmplitudeError
303 << " Chi2: " << stripChi2
304 );
305
306 // Finally, we save the RecoDigit and its relations.
307 map<unsigned int, float> mc_relations;
308 map<unsigned int, float> truehit_relations;
309 vector<pair<unsigned int, float> > digit_weights;
310 digit_weights.reserve(1);
311
312 // Store relations to MCParticles and SVDTrueHits
313 fillRelationMap(m_mcRelation, mc_relations, iDigit);
314 fillRelationMap(m_trueRelation, truehit_relations, iDigit);
315 //Add digit to the RecoDigit->ShaperDigit relation list
316 digit_weights.emplace_back(iDigit, 1.0);
317
// The index the new RecoDigit will occupy; it is the "from" index of its relations.
318 //Store the RecoDigit into Datastore ...
319 int recoDigitIndex = storeRecoDigits.getEntries();
320 storeRecoDigits.appendNew(
321 SVDRecoDigit(sensorID, isU, shaperDigit.getCellID(), stripAmplitude,
322 stripAmplitudeError, stripTime, stripTimeError, *pStrip, stripChi2)
323 );
324
325 //Create relations to RecoDigits
326 if (!mc_relations.empty()) {
327 relRecoDigitMCParticle.add(recoDigitIndex, mc_relations.begin(), mc_relations.end());
328 }
329 if (!truehit_relations.empty()) {
330 relRecoDigitTrueHit.add(recoDigitIndex, truehit_relations.begin(), truehit_relations.end());
331 }
332 relRecoDigitShaperDigit.add(recoDigitIndex, digit_weights.begin(), digit_weights.end());
333 } // CYCLE OVER SHAPERDIGITS
334
335 } // CYCLE OVER SENSORS for items in sensorDigits
336
337 B2DEBUG(100, "Number of RecoDigits: " << storeRecoDigits.getEntries());
338
339} // event()
340
341
Class for accessing objects in the database.
Definition: DBObjPtr.h:21
@ c_DontWriteOut
Object/array should NOT be saved by output modules.
Definition: DataStore.h:71
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
Low-level class to create/modify relations between StoreArrays.
Definition: RelationArray.h:62
void add(index_type from, index_type to, weight_type weight=1.0)
Add a new element to the relation.
void clear() override
Clear all elements from the relation.
Class to store a single element of a relation.
Class to store SVD mode information.
Definition: SVDModeByte.h:69
baseType getTriggerBin() const
Get the triggerBin id.
Definition: SVDModeByte.h:140
float getNoise(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
This is the method for getting the noise.
double getChargeFromADC(const Belle2::VxdID &sensorID, const bool &isU, const unsigned short &strip, const double &pulseADC) const
Return the charge (number of electrons/holes) collected on a specific strip, given the number of ADC ...
float getPeakTime(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
Return the peaking time of the strip.
float getWidth(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
Return the width of the pulse shape for a given strip.
The SVD RecoDigit class.
Definition: SVDRecoDigit.h:43
The SVD ShaperDigit class.
VxdID getSensorID() const
Get the sensor ID.
APVFloatSamples getSamples() const
Get array of samples.
short int getCellID() const
Get strip ID.
bool isUStrip() const
Get strip direction.
The class holds arrays of bins and bin centers, and a wave generator object containing information on...
Definition: NNWaveFitTool.h:91
std::tuple< double, double, double > getAmplitudeChi2(const apvSamples &samples, double timeShift, double tau)
Return std::tuple with signal amplitude, its error, and chi2 of the fit.
std::tuple< double, double > getTimeShift(const nnFitterBinData &p)
Return std::tuple containing time shift and its error.
void shiftInTime(nnFitterBinData &p, double timeShift)
Shift the probability array in time.
void setNetwrok(const std::string &xmlData)
Set the network definition from its XML data.
const NNWaveFitTool & getFitTool() const
Get a handle to the NNWaveFitTool object.
Definition: NNWaveFitter.h:98
std::shared_ptr< nnFitterBinData > getFit(const apvSamples &samples, double tau)
Fitting method: send the data and get the result structure.
std::string m_storeRecoDigitsName
Name of the collection to use for the SVDRecoDigits.
std::string m_relShaperDigitMCParticleName
Name of the relation between SVDShaperDigits and MCParticles.
bool m_writeRecoDigits
Write SVDRecoDigits? (no in normal operation)
virtual void initialize() override
Initialize the module.
SVDNNShapeReconstructorModule()
Constructor defining the parameters.
std::string m_storeShaperDigitsName
Name of the collection to use for the SVDShaperDigits.
virtual void event() override
Reconstruct SVDRecoDigits from the SVDShaperDigits of the current event.
std::vector< const RelationElement * > RelationLookup
Container for a RelationArray Lookup table.
std::string m_relRecoDigitTrueHitName
Name of the relation between SVDRecoDigits and SVDTrueHits.
std::string m_storeTrueHitsName
Name of the collection to use for the SVDTrueHits.
std::string m_relRecoDigitMCParticleName
Name of the relation between SVDRecoDigits and MCParticles.
std::string m_timeFitterName
Name of the time fitter data xml.
void fillRelationMap(const RelationLookup &lookup, std::map< unsigned int, float > &relation, unsigned int index)
Add the relation from a given SVDShaperDigit index to a map.
std::string m_storeMCParticlesName
Name of the collection to use for the MCParticles.
std::string m_relShaperDigitTrueHitName
Name of the relation between SVDShaperDigits and SVDTrueHits.
RelationLookup m_trueRelation
Lookup table for SVDShaperDigit->SVDTrueHit relation.
SVDPulseShapeCalibrations m_pulseShapeCal
Calibrations: pulse shape and gain.
std::string m_svdEventInfoName
Name of the SVDEventInfo object.
SVDNoiseCalibrations m_noiseCal
Calibrations: noise.
void createRelationLookup(const RelationArray &relation, RelationLookup &lookup, size_t digits)
Create lookup maps for relations FIXME: This has to be significantly simplified here,...
RelationLookup m_mcRelation
Lookup table for SVDShaperDigit->MCParticle relation.
StoreObjPtr< SVDEventInfo > m_storeSVDEvtInfo
Storage for SVDEventInfo object.
bool m_calibratePeak
Use peak widths and peak time calibrations? Until this is also simulated, set to true only for testbe...
std::string m_relRecoDigitShaperDigitName
Name of the relation between SVDRecoDigits and SVDShaperDigits.
const std::string & getName() const
Return name under which the object is saved in the DataStore.
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
bool isOptional(const std::string &name="")
Tell the DataStore about an optional input.
bool registerInDataStore(DataStore::EStoreFlags storeFlags=DataStore::c_WriteOut)
Register the object/array in the DataStore.
Accessor to arrays stored in the data store.
Definition: StoreArray.h:113
T * appendNew()
Construct a new T object at the end of the array.
Definition: StoreArray.h:246
int getEntries() const
Get the number of objects in the array.
Definition: StoreArray.h:216
void clear() override
Delete all entries in this array.
Definition: StoreArray.h:207
Class to uniquely identify any structure of the PXD and SVD.
Definition: VxdID.h:33
void setSegmentNumber(baseType segment)
Set the sensor segment.
Definition: VxdID.h:113
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Namespace to encapsulate code needed for simulation and reconstruction of the SVD.
Definition: GeoSVDCreator.h:23
std::array< apvSampleBaseType, nAPVSamples > apvSamples
Array of apvSampleBaseType objects.
void zeroSuppress(T &a, double thr)
pass zero suppression
bool pass3Samples(const T &a, double thr)
pass 3-samples
Abstract base class for different kinds of events.
STL namespace.