Belle II Software  release-08-01-10
SVDNNShapeReconstructorModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <svd/modules/svdReconstruction/SVDNNShapeReconstructorModule.h>
10 
11 #include <framework/datastore/DataStore.h>
12 #include <framework/datastore/StoreArray.h>
13 #include <framework/logging/Logger.h>
14 
15 #include <mdst/dataobjects/MCParticle.h>
16 #include <svd/dataobjects/SVDTrueHit.h>
17 #include <svd/dataobjects/SVDShaperDigit.h>
18 #include <svd/dataobjects/SVDRecoDigit.h>
19 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
20 
21 #include <svd/reconstruction/NNWaveFitTool.h>
22 
23 #include <algorithm>
24 #include <functional>
25 
26 using namespace std;
27 using namespace Belle2;
28 using namespace Belle2::SVD;
29 
30 //-----------------------------------------------------------------
31 // Register the Module
32 //-----------------------------------------------------------------
33 REG_MODULE(SVDNNShapeReconstructor);
34 
35 //-----------------------------------------------------------------
36 // Implementation
37 //-----------------------------------------------------------------
38 
39 SVDNNShapeReconstructorModule::SVDNNShapeReconstructorModule() : Module()
40 {
41  B2DEBUG(200, "Now in SVDNNShapeReconstructorModule ctor");
42  //Set module properties
43  setDescription("Reconstruct signals on SVD strips.");
45 
46  // 1. Collections.
48  "Shaperdigits collection name", string(""));
49  addParam("RecoDigits", m_storeRecoDigitsName,
50  "RecoDigits collection name", string(""));
51  addParam("TrueHits", m_storeTrueHitsName,
52  "TrueHits collection name", string(""));
53  addParam("MCParticles", m_storeMCParticlesName,
54  "MCParticles collection name", string(""));
55  addParam("WriteRecoDigits", m_writeRecoDigits,
56  "Write RecoDigits to output?", m_writeRecoDigits);
57  addParam("SVDEventInfo", m_svdEventInfoName,
58  "SVDEventInfo name", string(""));
59  // 2. Calibration and time fitter sources
60  addParam("TimeFitterName", m_timeFitterName,
61  "Name of time fitter data file", string("SVDTimeNet_6samples"));
62  addParam("CalibratePeak", m_calibratePeak, "Use calibrattion (vs. default) for peak widths and positions", bool(false));
63  // 3. Zero suppression
64  addParam("ZeroSuppressionCut", m_cutAdjacent, "Zero-suppression cut on digits",
66 }
67 
69 {
70  //Register collections
75 
76  if (!m_writeRecoDigits)
78  else
79  storeRecoDigits.registerInDataStore();
80 
81  storeShaperDigits.isRequired();
82  storeTrueHits.isOptional();
83  storeMCParticles.isOptional();
84  m_storeSVDEvtInfo.isRequired();
85 
86  if (!m_storeSVDEvtInfo.isOptional(m_svdEventInfoName)) m_svdEventInfoName = "SVDEventInfoSim";
88 
89  RelationArray relRecoDigitShaperDigits(storeRecoDigits, storeShaperDigits);
90  RelationArray relRecoDigitTrueHits(storeRecoDigits, storeTrueHits);
91  RelationArray relRecoDigitMCParticles(storeRecoDigits, storeMCParticles);
92  RelationArray relShaperDigitTrueHits(storeShaperDigits, storeTrueHits);
93  RelationArray relShaperDigitMCParticles(storeShaperDigits, storeMCParticles);
94 
95  if (!m_writeRecoDigits)
96  relRecoDigitShaperDigits.registerInDataStore(DataStore::c_DontWriteOut);
97  else
98  relRecoDigitShaperDigits.registerInDataStore();
99  //Relations to simulation objects only if the ancestor relations exist
100  if (relShaperDigitTrueHits.isOptional())
101  relRecoDigitTrueHits.registerInDataStore();
102  if (relShaperDigitMCParticles.isOptional())
103  relRecoDigitMCParticles.registerInDataStore();
104 
105  //Store names to speed up creation later
106  m_storeRecoDigitsName = storeRecoDigits.getName();
107  m_storeShaperDigitsName = storeShaperDigits.getName();
108  m_storeTrueHitsName = storeTrueHits.getName();
109  m_storeMCParticlesName = storeMCParticles.getName();
110 
111  m_relRecoDigitShaperDigitName = relRecoDigitShaperDigits.getName();
112  m_relRecoDigitTrueHitName = relRecoDigitTrueHits.getName();
113  m_relRecoDigitMCParticleName = relRecoDigitMCParticles.getName();
114  m_relShaperDigitTrueHitName = relShaperDigitTrueHits.getName();
115  m_relShaperDigitMCParticleName = relShaperDigitMCParticles.getName();
116 
117  B2INFO(" 1. COLLECTIONS:");
118  B2INFO(" --> MCParticles: " << m_storeMCParticlesName);
119  B2INFO(" --> Digits: " << m_storeShaperDigitsName);
120  B2INFO(" --> RecoDigits: " << m_storeRecoDigitsName);
121  B2INFO(" --> TrueHits: " << m_storeTrueHitsName);
122  B2INFO(" --> DigitMCRel: " << m_relShaperDigitMCParticleName);
123  B2INFO(" --> RecoDigitMCRel: " << m_relRecoDigitMCParticleName);
124  B2INFO(" --> RecoDigitDigitRel: " << m_relRecoDigitShaperDigitName);
125  B2INFO(" --> DigitTrueRel: " << m_relShaperDigitTrueHitName);
126  B2INFO(" --> RecoDigitTrueRel: " << m_relRecoDigitTrueHitName);
127  B2INFO(" --> Save RecoDigits? " << (m_writeRecoDigits ? "Y" : "N"));
128  B2INFO(" 2. CALIBRATION:");
129  B2INFO(" --> Time NN: " << m_timeFitterName);
130 
131  // Properly initialize the NN time fitter
132  // FIXME: Should be moved to beginRun
133  // FIXME: No support for 3/6 sample switching within a run/event
135  m_fitter.setNetwrok(dbXml->m_data);
136 }
137 
139  RelationLookup& lookup, size_t digits)
140 {
141  lookup.clear();
142  //If we don't have a relation we don't build a lookuptable
143  if (!relation) return;
144  //Resize to number of digits and set all values
145  lookup.resize(digits);
146  for (const auto& element : relation) {
147  lookup[element.getFromIndex()] = &element;
148  }
149 }
150 
152  std::map<unsigned int, float>& relation, unsigned int index)
153 {
154  //If the lookup table is not empty and the element is set
155  if (!lookup.empty() && lookup[index]) {
156  const RelationElement& element = *lookup[index];
157  const unsigned int size = element.getSize();
158  //Add all Relations to the map
159  for (unsigned int i = 0; i < size; ++i) {
160  //negative weights are from ignored particles, we don't like them and
161  //thus ignore them :D
162  if (element.getWeight(i) < 0) continue;
163  relation[element.getToIndex(i)] += element.getWeight(i);
164  }
165  }
166 }
167 
169 {
170 
171  const StoreArray<SVDShaperDigit> storeShaperDigits(m_storeShaperDigitsName);
172  // If no digits or no SVDEventInfo, nothing to do
173  if (!storeShaperDigits || !storeShaperDigits.getEntries() || !m_storeSVDEvtInfo.isValid()) return;
174 
175  SVDModeByte modeByte = m_storeSVDEvtInfo->getModeByte();
176 
177  size_t nDigits = storeShaperDigits.getEntries();
178  B2DEBUG(90, "Initial size of StoreDigits array: " << nDigits);
179 
180  const StoreArray<MCParticle> storeMCParticles(m_storeMCParticlesName);
181  const StoreArray<SVDTrueHit> storeTrueHits(m_storeTrueHitsName);
182 
183  RelationArray relShaperDigitMCParticle(storeShaperDigits, storeMCParticles, m_relShaperDigitMCParticleName);
184  RelationArray relShaperDigitTrueHit(storeShaperDigits, storeTrueHits, m_relShaperDigitTrueHitName);
185 
187  storeRecoDigits.clear();
188 
189  RelationArray relRecoDigitMCParticle(storeRecoDigits, storeMCParticles,
191  if (relRecoDigitMCParticle) relRecoDigitMCParticle.clear();
192 
193  RelationArray relRecoDigitShaperDigit(storeRecoDigits, storeShaperDigits,
195  if (relRecoDigitShaperDigit) relRecoDigitShaperDigit.clear();
196 
197  RelationArray relRecoDigitTrueHit(storeRecoDigits, storeTrueHits,
199  if (relRecoDigitTrueHit) relRecoDigitTrueHit.clear();
200 
201  //Build lookup tables for relations
202  createRelationLookup(relShaperDigitMCParticle, m_mcRelation, nDigits);
203  createRelationLookup(relShaperDigitTrueHit, m_trueRelation, nDigits);
204 
205  // Create fit tool object
206  NNWaveFitTool fitTool = m_fitter.getFitTool();
207 
208  // I. Group digits by sensor/side.
209  vector<pair<unsigned short, unsigned short> > sensorDigits;
210  VxdID lastSensorID(0);
211  size_t firstSensorDigit = 0;
212  for (size_t iDigit = 0; iDigit < nDigits; ++iDigit) {
213  const SVDShaperDigit& digit = *storeShaperDigits[iDigit];
214  VxdID sensorID = digit.getSensorID();
215  sensorID.setSegmentNumber(digit.isUStrip() ? 1 : 0);
216  if (sensorID != lastSensorID) { // we have a new sensor side
217  sensorDigits.push_back(make_pair(firstSensorDigit, iDigit));
218  firstSensorDigit = iDigit;
219  lastSensorID = sensorID;
220  }
221  }
222  // save last VxdID
223  sensorDigits.push_back(make_pair(firstSensorDigit, nDigits));
224 
225  // ICYCLE OVER SENSORS
226  for (auto id_indices : sensorDigits) {
227  // Retrieve parameters from sensorDigits
228  unsigned int firstDigit = id_indices.first;
229  unsigned int lastDigit = id_indices.second;
230  // Get VXDID and side from the first digit
231  const SVDShaperDigit& exampleDigit = *storeShaperDigits[firstDigit];
232  VxdID sensorID = exampleDigit.getSensorID();
233  bool isU = exampleDigit.isUStrip();
234 
235  // 2. Cycle through digits and form recodigits on the way.
236 
237  B2DEBUG(300, "Reconstructing digits " << firstDigit << " to " << lastDigit);
238  for (size_t iDigit = firstDigit; iDigit < lastDigit; ++iDigit) {
239 
240  const SVDShaperDigit& shaperDigit = *storeShaperDigits[iDigit];
241  unsigned short stripNo = shaperDigit.getCellID();
242  bool validDigit = true; // FIXME: We don't care about local run bad strips for now.
243  const double triggerBinSep = 4 * 1.96516; //in ns
244  double apvPhase = triggerBinSep * (0.5 + static_cast<int>(modeByte.getTriggerBin()));
245  // Get things from the database.
246  // Noise is good as it comes.
247  float stripNoiseADU = m_noiseCal.getNoise(sensorID, isU, stripNo);
248  // Some calibrations magic.
249  // FIXME: Only use calibration on real data. Until simulations correspond to
250  // default calibrtion, we cannot use it.
251  double stripSignalWidth = 270;
252  double stripT0 = isU ? 2.5 : -2.2;
253  if (m_calibratePeak) {
254  stripSignalWidth = 1.988 * m_pulseShapeCal.getWidth(sensorID, isU, stripNo);
255  stripT0 = m_pulseShapeCal.getPeakTime(sensorID, isU, stripNo)
256  - 0.25 * stripSignalWidth;
257  }
258 
259  B2DEBUG(300, "Strip parameters: stripNoiseADU: " << stripNoiseADU <<
260  " Width: " << stripSignalWidth <<
261  " T0: " << stripT0);
262 
263  // If the strip is not masked away, normalize samples (sample/stripNoise)
264  apvSamples normedSamples;
265  auto samples = shaperDigit.getSamples();
266  transform(samples.begin(), samples.end(), normedSamples.begin(),
267  bind2nd(divides<float>(), stripNoiseADU));
268  // FIXME: This won't work in 3 sample mode, we have no control over the number of non-zero samples.
269  validDigit = validDigit && pass3Samples(normedSamples, m_cutAdjacent);
270 
271  if (validDigit) {
272  zeroSuppress(normedSamples, m_cutAdjacent);
273  } else // only now we give up on the diigit
274  continue;
275 
276  // 3. Now we create and save the RecoDigit
277 
278  ostringstream os;
279  os << "Input to NNFitter: iDigit = " << iDigit << endl << "Samples: ";
280  copy(normedSamples.begin(), normedSamples.end(), ostream_iterator<double>(os, " "));
281  os << endl;
282  std::shared_ptr<nnFitterBinData> pStrip = m_fitter.getFit(normedSamples, stripSignalWidth);
283  os << "Output from NNWaveFitter: " << endl;
284  copy(pStrip->begin(), pStrip->end(), ostream_iterator<double>(os, " "));
285  os << endl;
286  // Apply strip time shift to pdf
287  fitTool.shiftInTime(*pStrip, -apvPhase - stripT0);
288  B2DEBUG(200, os.str());
289  // Calculate time and its error, amplitude and its error, and chi2
290  double stripTime, stripTimeError;
291  tie(stripTime, stripTimeError) = fitTool.getTimeShift(*pStrip);
292  // Now we have the cluster time pdf, so we can calculate amplitudes.
293  double stripAmplitude, stripAmplitudeError, stripChi2;
294  tie(stripAmplitude, stripAmplitudeError, stripChi2) =
295  fitTool.getAmplitudeChi2(normedSamples, stripTime, stripSignalWidth);
296  //De-normalize amplitudes and convert to electrons.
297  stripAmplitude = m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripAmplitude * stripNoiseADU);
298  stripAmplitudeError = m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripAmplitudeError * stripNoiseADU);
299  B2DEBUG(200, "RecoDigit " << iDigit << " Noise: " << m_pulseShapeCal.getChargeFromADC(sensorID, isU, stripNo, stripNoiseADU)
300  << " Time: " << stripTime << " +/- " << stripTimeError
301  << " Amplitude: " << stripAmplitude << " +/- " << stripAmplitudeError
302  << " Chi2: " << stripChi2
303  );
304 
305  // Finally, we save the RecoDigit and its relations.
306  map<unsigned int, float> mc_relations;
307  map<unsigned int, float> truehit_relations;
308  vector<pair<unsigned int, float> > digit_weights;
309  digit_weights.reserve(1);
310 
311  // Store relations to MCParticles and SVDTrueHits
312  fillRelationMap(m_mcRelation, mc_relations, iDigit);
313  fillRelationMap(m_trueRelation, truehit_relations, iDigit);
314  //Add digit to the RecoDigit->ShaperDigit relation list
315  digit_weights.emplace_back(iDigit, 1.0);
316 
317  //Store the RecoDigit into Datastore ...
318  int recoDigitIndex = storeRecoDigits.getEntries();
319  storeRecoDigits.appendNew(
320  SVDRecoDigit(sensorID, isU, shaperDigit.getCellID(), stripAmplitude,
321  stripAmplitudeError, stripTime, stripTimeError, *pStrip, stripChi2)
322  );
323 
324  //Create relations to RecoDigits
325  if (!mc_relations.empty()) {
326  relRecoDigitMCParticle.add(recoDigitIndex, mc_relations.begin(), mc_relations.end());
327  }
328  if (!truehit_relations.empty()) {
329  relRecoDigitTrueHit.add(recoDigitIndex, truehit_relations.begin(), truehit_relations.end());
330  }
331  relRecoDigitShaperDigit.add(recoDigitIndex, digit_weights.begin(), digit_weights.end());
332  } // CYCLE OVER SHAPERDIGITS
333 
334  } // CYCLE OVER SENSORS for items in sensorDigits
335 
336  B2DEBUG(100, "Number of RecoDigits: " << storeRecoDigits.getEntries());
337 
338 } // event()
339 
340 
Class for accessing objects in the database.
Definition: DBObjPtr.h:21
@ c_DontWriteOut
Object/array should be NOT saved by output modules.
Definition: DataStore.h:71
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
Low-level class to create/modify relations between StoreArrays.
Definition: RelationArray.h:62
void add(index_type from, index_type to, weight_type weight=1.0)
Add a new element to the relation.
void clear() override
Clear all elements from the relation.
Class to store a single element of a relation.
Class to store SVD mode information.
Definition: SVDModeByte.h:69
baseType getTriggerBin() const
Get the triggerBin id.
Definition: SVDModeByte.h:140
float getNoise(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
This is the method for getting the noise.
double getChargeFromADC(const Belle2::VxdID &sensorID, const bool &isU, const unsigned short &strip, const double &pulseADC) const
Return the charge (number of electrons/holes) collected on a specific strip, given the number of ADC ...
float getPeakTime(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
Return the peaking time of the strip.
float getWidth(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
Return the width of the pulse shape for a given strip.
The SVD RecoDigit class.
Definition: SVDRecoDigit.h:43
The SVD ShaperDigit class.
VxdID getSensorID() const
Get the sensor ID.
APVFloatSamples getSamples() const
Get array of samples.
short int getCellID() const
Get strip ID.
bool isUStrip() const
Get strip direction.
The class holds arrays of bins and bin centers, and a wave generator object containing information on...
Definition: NNWaveFitTool.h:91
std::tuple< double, double, double > getAmplitudeChi2(const apvSamples &samples, double timeShift, double tau)
Return std::tuple with signal amplitude, its error, and chi2 of the fit.
std::tuple< double, double > getTimeShift(const nnFitterBinData &p)
Return std::tuple containing time shift and its error.
void shiftInTime(nnFitterBinData &p, double timeShift)
Shift the probability array in time.
void setNetwrok(const std::string &xmlData)
Set proper network definition file.
const NNWaveFitTool & getFitTool() const
Get a handle to an NNWaveFitTool object.
Definition: NNWaveFitter.h:98
std::shared_ptr< nnFitterBinData > getFit(const apvSamples &samples, double tau)
Fitting method: send data and get the result structure.
std::string m_storeRecoDigitsName
Name of the collection to use for the SVDRecoDigits.
std::string m_relShaperDigitMCParticleName
Name of the relation between SVDShaperDigits and MCParticles.
bool m_writeRecoDigits
Write SVDRecoDigits? (no in normal operation)
virtual void initialize() override
Initialize the module.
std::string m_storeShaperDigitsName
Name of the collection to use for the SVDShaperDigits.
virtual void event() override
Reconstruct SVDRecoDigits from SVDShaperDigits.
std::vector< const RelationElement * > RelationLookup
Container for a RelationArray Lookup table.
std::string m_relRecoDigitTrueHitName
Name of the relation between SVDRecoDigits and SVDTrueHits.
std::string m_storeTrueHitsName
Name of the collection to use for the SVDTrueHits.
std::string m_relRecoDigitMCParticleName
Name of the relation between SVDRecoDigits and MCParticles.
std::string m_timeFitterName
Name of the time fitter data xml.
void fillRelationMap(const RelationLookup &lookup, std::map< unsigned int, float > &relation, unsigned int index)
Add the relation from a given SVDShaperDigit index to a map.
std::string m_storeMCParticlesName
Name of the collection to use for the MCParticles.
std::string m_relShaperDigitTrueHitName
Name of the relation between SVDShaperDigits and SVDTrueHits.
RelationLookup m_trueRelation
Lookup table for SVDShaperDigit->SVDTrueHit relation.
SVDPulseShapeCalibrations m_pulseShapeCal
Calibrations: pulse shape and gain.
std::string m_svdEventInfoName
Name of the SVDEventInfo object.
SVDNoiseCalibrations m_noiseCal
Calibrations: noise.
void createRelationLookup(const RelationArray &relation, RelationLookup &lookup, size_t digits)
Create lookup maps for relations. FIXME: This has to be significantly simplified here,...
RelationLookup m_mcRelation
Lookup table for SVDShaperDigit->MCParticle relation.
StoreObjPtr< SVDEventInfo > m_storeSVDEvtInfo
Storage for SVDEventInfo object.
bool m_calibratePeak
Use peak widths and peak time calibrations? Until this is also simulated, set to true only for testbe...
std::string m_relRecoDigitShaperDigitName
Name of the relation between SVDRecoDigits and SVDShaperDigits.
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
bool isOptional(const std::string &name="")
Tell the DataStore about an optional input.
const std::string & getName() const
Return name under which the object is saved in the DataStore.
bool registerInDataStore(DataStore::EStoreFlags storeFlags=DataStore::c_WriteOut)
Register the object/array in the DataStore.
Accessor to arrays stored in the data store.
Definition: StoreArray.h:113
T * appendNew()
Construct a new T object at the end of the array.
Definition: StoreArray.h:246
int getEntries() const
Get the number of objects in the array.
Definition: StoreArray.h:216
void clear() override
Delete all entries in this array.
Definition: StoreArray.h:207
Class to uniquely identify any structure of the PXD and SVD.
Definition: VxdID.h:33
void setSegmentNumber(baseType segment)
Set the sensor segment.
Definition: VxdID.h:113
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Namespace to encapsulate code needed for simulation and reconstruction of the SVD.
Definition: GeoSVDCreator.h:23
void zeroSuppress(T &a, double thr)
pass zero suppression
std::array< apvSampleBaseType, nAPVSamples > apvSamples
array of apvSampleBaseType objects
bool pass3Samples(const T &a, double thr)
pass 3-samples
Abstract base class for different kinds of events.