9#include <svd/modules/svdReconstruction/SVDNNClusterizerModule.h>
11#include <framework/datastore/StoreArray.h>
12#include <framework/logging/Logger.h>
14#include <vxd/geometry/GeoCache.h>
15#include <svd/geometry/SensorInfo.h>
17#include <mdst/dataobjects/MCParticle.h>
18#include <svd/dataobjects/SVDTrueHit.h>
19#include <svd/dataobjects/SVDShaperDigit.h>
20#include <svd/dataobjects/SVDRecoDigit.h>
21#include <svd/dataobjects/SVDCluster.h>
22#include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
24#include <svd/reconstruction/NNWaveFitTool.h>
32using namespace std::placeholders;
47 B2DEBUG(200,
"SVDNNClusterizerModule ctor");
54 "RecoDigits collection name",
string(
""));
56 "Cluster collection name",
string(
""));
58 "TrueHit collection name",
string(
""));
60 "MCParticles collection name",
string(
""));
64 "Name of time fitter data file",
string(
"SVDTimeNet_6samples"));
65 addParam(
"CalibratePeak",
m_calibratePeak,
"Use calibrattion (vs. default) for peak widths and positions",
bool(
false));
70 "SN for digits to be considered for clustering",
m_cutAdjacent);
72 "SN for digits to be considered as seed",
m_cutSeed);
76 "Cluster size at which to switch to Analog head tail algorithm",
m_sizeHeadTail);
93 RelationArray relClusterRecoDigits(storeClusters, storeRecoDigits);
94 RelationArray relClusterTrueHits(storeClusters, storeTrueHits);
95 RelationArray relClusterMCParticles(storeClusters, storeMCParticles);
96 RelationArray relRecoDigitTrueHits(storeRecoDigits, storeTrueHits);
97 RelationArray relRecoDigitMCParticles(storeRecoDigits, storeMCParticles);
118 B2INFO(
" 1. COLLECTIONS:");
128 B2INFO(
" 2. CALIBRATION DATA:");
130 B2INFO(
" 4. CLUSTERING:");
148 if (!relation)
return;
150 lookup.resize(digits);
151 for (
const auto& element : relation) {
152 lookup[element.getFromIndex()] = &element;
157 std::map<unsigned int, float>& relation,
unsigned int index)
160 if (!lookup.empty() && lookup[index]) {
162 const unsigned int size = element.getSize();
164 for (
unsigned int i = 0; i < size; ++i) {
167 if (element.getWeight(i) < 0)
continue;
168 relation[element.getToIndex(i)] += element.getWeight(i);
177 if (!storeRecoDigits || !storeRecoDigits.
getEntries())
return;
179 size_t nDigits = storeRecoDigits.
getEntries();
180 B2DEBUG(90,
"Initial size of RecoDigits array: " << nDigits);
189 storeClusters.
clear();
191 RelationArray relClusterMCParticle(storeClusters, storeMCParticles,
193 if (relClusterMCParticle) relClusterMCParticle.
clear();
195 RelationArray relClusterRecoDigit(storeClusters, storeRecoDigits,
197 if (relClusterRecoDigit) relClusterRecoDigit.
clear();
199 RelationArray relClusterTrueHit(storeClusters, storeTrueHits,
201 if (relClusterTrueHit) relClusterTrueHit.
clear();
211 vector<pair<unsigned short, unsigned short> > sensorDigits;
212 VxdID lastSensorID(0);
213 size_t firstSensorDigit = 0;
214 for (
size_t iDigit = 0; iDigit < nDigits; ++iDigit) {
218 if (sensorID != lastSensorID) {
219 sensorDigits.push_back(make_pair(firstSensorDigit, iDigit));
220 firstSensorDigit = iDigit;
221 lastSensorID = sensorID;
225 sensorDigits.push_back(make_pair(firstSensorDigit, nDigits));
228 for (
auto id_indices : sensorDigits) {
230 unsigned int firstDigit = id_indices.first;
231 unsigned int lastDigit = id_indices.second;
233 const SVDRecoDigit& sampleRecoDigit = *storeRecoDigits[firstDigit];
235 bool isU = sampleRecoDigit.
isUStrip();
250 vector<pair<size_t, size_t> > stripGroups;
251 size_t firstClusterDigit = firstDigit;
252 size_t lastClusterDigit = firstDigit;
253 short lastStrip = -2;
255 B2DEBUG(300,
"Clustering digits " << firstDigit <<
" to " << lastDigit);
256 for (
size_t iDigit = firstDigit; iDigit < lastDigit; ++iDigit) {
258 const SVDRecoDigit& recoDigit = *storeRecoDigits[iDigit];
259 unsigned short currentStrip = recoDigit.
getCellID();
260 B2DEBUG(300,
"Digit " << iDigit <<
", strip: " << currentStrip <<
", lastStrip: " << lastStrip);
261 B2DEBUG(300,
"First CD: " << firstClusterDigit <<
" Last CD: " << lastClusterDigit);
264 bool consecutive = ((currentStrip - lastStrip) == 1);
265 lastStrip = currentStrip;
267 B2DEBUG(300, (consecutive ?
"consecutive" :
"gap"));
270 if (!consecutive && (firstClusterDigit < lastClusterDigit)) {
271 B2DEBUG(300,
"Saving (" << firstClusterDigit <<
", " << lastClusterDigit <<
")");
272 stripGroups.emplace_back(firstClusterDigit, lastClusterDigit);
277 lastClusterDigit = iDigit + 1;
279 firstClusterDigit = iDigit;
280 lastClusterDigit = iDigit + 1;
284 if (firstClusterDigit < lastClusterDigit) {
285 B2DEBUG(300,
"Saving (" << firstClusterDigit <<
", " << lastDigit <<
")");
286 stripGroups.emplace_back(firstClusterDigit, lastDigit);
290 os <<
"StripGroups: " << endl;
291 for (
auto item : stripGroups) {
292 os <<
"(" << item.first <<
", " << item.second <<
"), ";
294 B2DEBUG(300, os.str());
303 vector<unsigned short> stripNumbers;
304 vector<float> stripPositions;
305 vector<float> stripNoises;
306 vector<float> stripGains;
307 vector<float> timeShifts;
308 vector<float> waveWidths;
309 vector<apvSamples> storedNormedSamples;
310 vector<SVDRecoDigit::OutputProbArray> storedPDFs;
313 for (
auto clusterBounds : stripGroups) {
315 unsigned short clusterSize = clusterBounds.second - clusterBounds.first;
316 assert(clusterSize > 0);
318 stripNumbers.clear();
319 stripPositions.clear();
325 for (
size_t iDigit = clusterBounds.first; iDigit < clusterBounds.second; ++iDigit) {
329 const SVDRecoDigit& recoDigit = *storeRecoDigits[iDigit];
331 unsigned short stripNo = recoDigit.
getCellID();
332 stripNumbers.push_back(stripNo);
335 stripNoises.push_back(
341 double peakWidth = 270;
342 double timeShift = isU ? 4.0 : 0.0;
348 waveWidths.push_back(peakWidth);
349 timeShifts.push_back(timeShift);
350 stripPositions.push_back(
357 B2FATAL(
"Missing SVDRecoDigits->SVDShaperDigits relation. This should not happen.");
359 transform(samples.begin(), samples.end(), normedSamples.begin(),
360 bind(divides<float>(), _1, stripNoiseADU));
365 storedNormedSamples.emplace_back(normedSamples);
371 float clusterNoise =
sqrt(
373 * inner_product(stripNoises.begin(), stripNoises.end(), stripNoises.begin(), 0.0)
375 B2DEBUG(200,
"RMS cluster noise: " << clusterNoise);
379 shared_ptr<nnFitterBinData> pStrip;
382 fill(pCluster.begin(), pCluster.end(),
double(1.0));
384 for (
size_t iClusterStrip = 0; iClusterStrip < clusterSize; ++iClusterStrip) {
385 size_t iDigit = clusterBounds.first + iClusterStrip;
387 os1 <<
"Input to NNFitter: iDigit = " << iDigit << endl <<
"Samples: ";
388 copy(storedNormedSamples[iClusterStrip].begin(), storedNormedSamples[iClusterStrip].end(),
389 ostream_iterator<double>(os1,
" "));
391 os1 <<
"PDF from RecoDigit: " << endl;
392 copy(storedPDFs[iClusterStrip].begin(), storedPDFs[iClusterStrip].end(), ostream_iterator<double>(os1,
" "));
394 fitTool.
multiply(pCluster, storedPDFs[iClusterStrip]);
395 os1 <<
"Accummulated: " << endl;
396 copy(pCluster.begin(), pCluster.end(), ostream_iterator<double>(os1,
" "));
397 B2DEBUG(200, os1.str());
400 double clusterTime, clusterTimeErr;
401 tie(clusterTime, clusterTimeErr) = fitTool.
getTimeShift(pCluster);
402 B2DEBUG(200,
"Time: " << clusterTime <<
" +/- " << clusterTimeErr);
406 vector<double> stripAmplitudes(stripNoises.size());
407 vector<double> stripAmplitudeErrors(stripNoises.size());
408 double clusterChi2 = 0.0;
409 for (
size_t iClusterStrip = 0; iClusterStrip < clusterSize; ++iClusterStrip) {
410 size_t iDigit = clusterBounds.first + iClusterStrip;
411 double snAmp, snAmpError, chi2;
412 tie(snAmp, snAmpError, chi2) =
413 fitTool.
getAmplitudeChi2(storedNormedSamples[iClusterStrip], clusterTime, waveWidths[iClusterStrip]);
415 stripAmplitudes[iClusterStrip] = stripNoises[iClusterStrip] * snAmp;
416 stripAmplitudeErrors[iClusterStrip] = stripNoises[iClusterStrip] * snAmpError;
418 B2DEBUG(200,
"Digit " << iDigit <<
" Noise: " << stripNoises[iClusterStrip]
419 <<
" Amplitude: " << stripAmplitudes[iClusterStrip]
420 <<
" +/- " << stripAmplitudeErrors[iClusterStrip]
425 float clusterCharge = accumulate(stripAmplitudes.begin(), stripAmplitudes.end(), 0.0);
426 float clusterChargeError =
sqrt(
427 inner_product(stripAmplitudeErrors.begin(), stripAmplitudeErrors.end(),
428 stripAmplitudeErrors.begin(), 0.0)
430 float clusterSN = (clusterChargeError > 0) ? clusterCharge / clusterChargeError : clusterCharge;
432 clusterChi2 /= clusterSize;
434 size_t seedIndex = distance(stripAmplitudes.begin(), max_element(
435 stripAmplitudes.begin(), stripAmplitudes.end()));
436 float clusterSeedCharge = stripAmplitudes[seedIndex];
437 B2DEBUG(200,
"Cluster parameters:");
438 B2DEBUG(200,
"Charge: " << clusterCharge <<
" +/- " << clusterChargeError);
439 B2DEBUG(200,
"Seed: " << clusterSeedCharge <<
" +/- " << stripAmplitudeErrors[seedIndex]);
440 B2DEBUG(200,
"S/N: " << clusterSN);
441 B2DEBUG(200,
"chi2: " << clusterChi2);
444 float clusterPosition, clusterPositionError;
453 clusterPosition = 1.0 / clusterCharge * inner_product(
454 stripAmplitudes.begin(), stripAmplitudes.end(), stripPositions.begin(), 0.0
458 if (clusterSize == 1) {
459 clusterPositionError = pitch * phantomCharge / (clusterCharge + phantomCharge);
461 clusterPositionError = pitch * phantomCharge / clusterCharge;
464 float leftStripCharge = stripAmplitudes.front();
465 float leftPos = stripPositions.front();
466 float rightStripCharge = stripAmplitudes.back();
467 float rightPos = stripPositions.back();
468 float centreCharge = (clusterCharge - leftStripCharge - rightStripCharge) / (clusterSize - 2);
469 leftStripCharge = (leftStripCharge < centreCharge) ? leftStripCharge : centreCharge;
470 rightStripCharge = (rightStripCharge < centreCharge) ? rightStripCharge : centreCharge;
471 clusterPosition = 0.5 * (leftPos + rightPos)
472 + 0.5 * (rightStripCharge - leftStripCharge) / centreCharge * pitch;
475 float landauHead = leftStripCharge / centreCharge;
476 double landauTail = rightStripCharge / centreCharge;
477 clusterPositionError = 0.5 * pitch *
sqrt(1.0 / sn / sn
478 + 0.5 * landauHead * landauHead + 0.5 * landauTail * landauTail);
484 map<unsigned int, float> mc_relations;
485 map<unsigned int, float> truehit_relations;
486 vector<pair<unsigned int, float> > digit_weights;
487 digit_weights.reserve(clusterSize);
489 for (
size_t iDigit = clusterBounds.first; iDigit < clusterBounds.second; ++iDigit) {
494 digit_weights.emplace_back(iDigit, stripAmplitudes[iDigit - clusterBounds.first]);
499 VxdID clusterSensorID = sensorID;
502 SVDCluster(sensorID, isU, clusterPosition, clusterPositionError, clusterTime,
503 clusterTimeErr, clusterCharge, clusterSeedCharge, clusterSize, clusterSN, clusterChi2)
507 if (!mc_relations.empty()) {
508 relClusterMCParticle.
add(clsIndex, mc_relations.begin(), mc_relations.end());
510 if (!truehit_relations.empty()) {
511 relClusterTrueHit.
add(clsIndex, truehit_relations.begin(), truehit_relations.end());
513 relClusterRecoDigit.
add(clsIndex, digit_weights.begin(), digit_weights.end());
518 B2DEBUG(100,
"Number of clusters: " << storeClusters.
getEntries());
Class for accessing objects in the database.
void setDescription(const std::string &description)
Sets the description of the module.
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Low-level class to create/modify relations between StoreArrays.
void add(index_type from, index_type to, weight_type weight=1.0)
Add a new element to the relation.
void clear() override
Clear all elements from the relation.
Class to store a single element of a relation.
TO * getRelatedTo(const std::string &name="", const std::string &namedRelation="") const
Get the object to which this object has a relation.
The SVD Cluster class This class stores all information about reconstructed SVD clusters.
float getNoise(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
This is the method for getting the noise.
double getChargeFromADC(const Belle2::VxdID &sensorID, const bool &isU, const unsigned short &strip, const double &pulseADC) const
Return the charge (number of electrons/holes) collected on a specific strip, given the number of ADC ...
float getPeakTime(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
Return the peaking time of the strip.
float getWidth(const VxdID &sensorID, const bool &isU, const unsigned short &strip) const
Return the width of the pulse shape for a given strip.
OutputProbArray getProbabilities() const
Get signal time pdf.
VxdID getSensorID() const
Get the sensor ID.
short int getCellID() const
Get strip ID.
bool isUStrip() const
Get strip direction.
The SVD ShaperDigit class.
APVFloatSamples getSamples() const
Get array of samples.
void setNetwrok(const std::string &xmlData)
Set proper network definition file.
const NNWaveFitTool & getFitTool() const
Get a handle to a NNWaveFit object.
const nnFitterBinData & getBinCenters() const
Get bin times of the network output.
std::string m_storeRecoDigitsName
Name of the collection to use for the SVDRecoDigits.
virtual void initialize() override
Initialize the module.
virtual void event() override
do the clustering
std::vector< const RelationElement * > RelationLookup
Container for a RelationArray Lookup table.
std::string m_relRecoDigitTrueHitName
Name of the relation between SVDRecoDigits and SVDTrueHits.
double m_cutCluster
Cluster cut in units of m_elNoise.
std::string m_storeTrueHitsName
Name of the collection to use for the SVDTrueHits.
std::string m_relRecoDigitMCParticleName
Name of the relation between SVDRecoDigits and MCParticles.
std::string m_timeFitterName
Name of the time fitter (db label)
void fillRelationMap(const RelationLookup &lookup, std::map< unsigned int, float > &relation, unsigned int index)
Add the relation from a given SVDRecoDigit index to a map.
std::string m_storeMCParticlesName
Name of the collection to use for the MCParticles.
RelationLookup m_trueRelation
Lookup table for SVDRecoDigit->SVDTrueHit relation.
SVDPulseShapeCalibrations m_pulseShapeCal
Calibrations: pulse shape and gain.
SVDNoiseCalibrations m_noiseCal
Calibrations: noise.
NNWaveFitter m_fitter
Time fitter.
int m_sizeHeadTail
Size of the cluster at which we switch from Center of Gravity to Analog Head Tail.
void createRelationLookup(const RelationArray &relation, RelationLookup &lookup, size_t digits)
Create lookup maps for relations. We do not use the RelationIndex as we know much more about the relat...
std::string m_storeClustersName
Name of the collection to use for the SVDClusters.
double m_cutSeed
Seed cut in units of m_elNoise.
std::string m_relClusterMCParticleName
Name of the relation between SVDClusters and MCParticles.
SVDNNClusterizerModule()
Constructor defining the parameters.
RelationLookup m_mcRelation
Lookup table for SVDRecoDigit->MCParticle relation.
std::string m_relClusterRecoDigitName
Name of the relation between SVDClusters and SVDRecoDigits.
bool m_calibratePeak
Use peak widths and peak time calibrations? Until this is also simulated, set to true only for testbe...
double m_cutAdjacent
Noise (cluster member) cut in units of m_elNoise.
std::string m_relClusterTrueHitName
Name of the relation between SVDClusters and SVDTrueHits.
Specific implementation of SensorInfo for SVD Sensors which provides additional sensor specific infor...
const ROOT::Math::XYZVector & getLorentzShift(double uCoord, double vCoord) const
Calculate Lorentz shift along a given coordinate in a magnetic field at a given position.
const std::string & getName() const
Return name under which the object is saved in the DataStore.
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
bool isOptional(const std::string &name="")
Tell the DataStore about an optional input.
bool registerInDataStore(DataStore::EStoreFlags storeFlags=DataStore::c_WriteOut)
Register the object/array in the DataStore.
Accessor to arrays stored in the data store.
T * appendNew()
Construct a new T object at the end of the array.
int getEntries() const
Get the number of objects in the array.
void clear() override
Delete all entries in this array.
const SensorInfoBase & getSensorInfo(Belle2::VxdID id) const
Return a reference to the SensorInfo of a given SensorID.
static GeoCache & getInstance()
Return a reference to the singleton instance.
double getVCellPosition(int vID) const
Return the position of a specific strip/pixel in v direction.
double getUPitch(double v=0) const
Return the pitch of the sensor.
double getUCellPosition(int uID, int vID=-1) const
Return the position of a specific strip/pixel in u direction.
double getVPitch(double v=0) const
Return the pitch of the sensor.
Class to uniquely identify any structure of the PXD and SVD.
void setSegmentNumber(baseType segment)
Set the sensor segment.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
double sqrt(double a)
sqrt for double
Namespace to encapsulate code needed for simulation and reconstruction of the SVD.
std::array< apvSampleBaseType, nAPVSamples > apvSamples
Array of apvSampleBaseType values, one per APV sample (nAPVSamples entries).
std::vector< double > nnFitterBinData
Vector of values defined for bins, such as bin times or bin probabilities.
void zeroSuppress(T &a, double thr)
pass zero suppression
Abstract base class for different kinds of events.