Belle II Software development
CDCDedxPIDCreatorModule.cc
/**************************************************************************
 * basf2 (Belle II Analysis Software Framework)                           *
 * Author: The Belle II Collaboration                                     *
 *                                                                        *
 * See git log for contributors and copyright holders.                    *
 * This file is licensed under LGPL-3.0, see LICENSE.md.                  *
 **************************************************************************/
8
9#include <cdc/modules/CDCDedxPID/CDCDedxPIDCreatorModule.h>
10#include <cdc/geometry/CDCGeometryPar.h>
11#include <cdc/translators/LinearGlobalADCCountTranslator.h>
12#include <cdc/translators/RealisticTDCCountTranslator.h>
13#include <cdc/modules/CDCDedxPID/LineHelper.h>
14#include <framework/core/Environment.h>
15#include <TRandom.h>
16#include <cmath>
17#include <algorithm>
18#include <map>
19#include <vector>
20
21using namespace std;
22
23namespace Belle2 {
28
29 using namespace CDC;
30
31 //-----------------------------------------------------------------
33 //-----------------------------------------------------------------
34
35 REG_MODULE(CDCDedxPIDCreator);
36
37 //-----------------------------------------------------------------
38 // Implementation
39 //-----------------------------------------------------------------
40
42
43 {
44 // set module description
45 setDescription("Module that creates PID likelihoods from CDC hit information stored in CDCDedxHits "
46 "using parameterized means and resolutions.");
48
49 addParam("removeLowest", m_removeLowest,
50 "Portion of events with low dE/dx that should be discarded", double(0.05));
51 addParam("removeHighest", m_removeHighest,
52 "Portion of events with high dE/dx that should be discarded", double(0.25));
53 addParam("useBackHalfCurlers", m_useBackHalfCurlers,
54 "Whether to use the back half of curlers", false);
55 addParam("trackLevel", m_trackLevel,
56 "ONLY USEFUL FOR MC: Use track-level MC (generate truncated mean from predicted mean and sigma using MC truth). "
57 "If false, use hit-level MC (use truncated mean determined from hits)", true);
58 }
59
63
65 {
66 m_tracks.isRequired();
67 m_hits.isOptional(); // in order to run also with old cdst's where this collection doesn't exist
68 m_mcParticles.isOptional();
69 m_TTDInfo.isOptional();
70 m_likelihoods.registerInDataStore();
71 m_tracks.registerRelationTo(m_likelihoods);
72 m_dedxTracks.registerInDataStore();
73 m_tracks.registerRelationTo(m_dedxTracks);
74
75 m_nLayerWires[0] = 1280;
76 for (int i = 1; i < 9; ++i) {
77 m_nLayerWires[i] = m_nLayerWires[i - 1] + 6 * (160 + (i - 1) * 32);
78 }
79
80 if (not m_trackLevel)
81 B2WARNING("Hit-level MC still needs a precise calibration to perform well! Until then please use track-level MC.");
82
83 }
84
86 {
87 // check if CDCDedxHits are present; return if not.
88 if (not m_hits.isValid()) {
90 if (m_warnCount < 10) {
91 B2WARNING("StoreArray 'CDCDedxHits' does not exist, returning. Probably running on old cdst.");
92 } else if (m_warnCount == 10) {
93 B2WARNING("StoreArray 'CDCDedxHits' does not exist, returning. ...message will be suppressed now.");
94 }
95 return;
96 }
97
98 // clear output collections
99 m_likelihoods.clear();
100 m_dedxTracks.clear();
101
102 // CDC geometry parameters and translators
103 const auto& cdcgeo = CDCGeometryPar::Instance();
104 LinearGlobalADCCountTranslator adcTranslator;
105 RealisticTDCCountTranslator tdcTranslator;
106
107 // is data or MC ?
108 bool isData = not Environment::Instance().isMC();
109
110 // track independent calibration constants
111 double runGain = isData ? m_DBRunGain->getRunGain() : 1.0;
112 double timeGain = 1;
113 double timeReso = 1; // this is multiplicative constant
114 if (isData and m_TTDInfo.isValid() and m_TTDInfo->hasInjection()) {
115 timeGain = m_DBInjectTime->getCorrection("mean", m_TTDInfo->isHER(), m_TTDInfo->getTimeSinceLastInjectionInMicroSeconds());
116 timeReso = m_DBInjectTime->getCorrection("reso", m_TTDInfo->isHER(), m_TTDInfo->getTimeSinceLastInjectionInMicroSeconds());
117 }
118 double scale = m_DBScaleFactor->getScaleFactor(); // scale factor to make electron dE/dx ~ 1
119 if (scale == 0) {
120 B2ERROR("Scale factor from DB is zero! Will be set to one");
121 scale = 1;
122 }
123
124 // loop over tracks
125 for (const auto& track : m_tracks) {
126 // track fit result
127 const auto* fitResult = track.getTrackFitResultWithClosestMass(Const::pion);
128 if (not fitResult) {
129 B2WARNING("No related fit for track, skip it.");
130 continue;
131 }
132
133 // hits of this track
134 const auto hits = track.getRelationsTo<CDCDedxHit>();
135 if (hits.size() == 0) continue;
136
137 // track momentum
138 const auto& trackMom = fitResult->getMomentum();
139 double theta = trackMom.Theta();
140 double cosTheta = std::cos(theta);
141 double sinTheta = std::sin(theta);
142
143 // track dependent calibration constants
144 double cosCor = isData ? m_DBCosineCor->getMean(cosTheta) : 1.0;
145 bool isEdge = std::abs(cosTheta + 0.860) < 0.010 or std::abs(cosTheta - 0.955) <= 0.005;
146 double cosEdgeCor = (isData and isEdge) ? m_DBCosEdgeCor->getMean(cosTheta) : 1.0;
147
148 // MC particle
149 const auto* mcParticle = isData ? nullptr : track.getRelated<MCParticle>();
150
151 // debug output
152 CDCDedxTrack* dedxTrack = m_dedxTracks.appendNew();
153 if (dedxTrack) {
154 dedxTrack->m_track = track.getArrayIndex();
155 dedxTrack->m_charge = fitResult->getChargeSign();
156 dedxTrack->m_cosTheta = cosTheta;
157 dedxTrack->m_p = trackMom.R();
158 if (isData and m_TTDInfo.isValid() and m_TTDInfo->hasInjection()) {
159 dedxTrack->m_injring = m_TTDInfo->isHER();
160 dedxTrack->m_injtime = m_TTDInfo->getTimeSinceLastInjectionInMicroSeconds();
161 }
162 if (mcParticle) {
163 dedxTrack->m_pdg = mcParticle->getPDG();
164 dedxTrack->m_mcmass = mcParticle->getMass();
165 const auto* mother = mcParticle->getMother();
166 dedxTrack->m_motherPDG = mother ? mother->getPDG() : 0;
167 const auto& trueMom = mcParticle->getMomentum();
168 dedxTrack->m_pTrue = trueMom.R();
169 dedxTrack->m_cosThetaTrue = std::cos(trueMom.Theta());
170 }
171 dedxTrack->m_scale = scale;
172 dedxTrack->m_cosCor = cosCor;
173 dedxTrack->m_cosEdgeCor = cosEdgeCor;
174 dedxTrack->m_runGain = runGain;
175 dedxTrack->m_timeGain = timeGain;
176 dedxTrack->m_timeReso = timeReso;
177 }
178
179 // loop over hits
180 int lastLayer = -1;
181 double pCDC = 0;
182 std::map<int, DEDX> dedxWires;
183 for (const auto& hit : hits) {
184 // wire numbering: layer and superlayer
185 const auto& wireID = hit.getWireID();
186 int layer = wireID.getILayer(); // layer within superlayer
187 int superlayer = wireID.getISuperLayer();
188 int currentLayer = (superlayer == 0) ? layer : (8 + (superlayer - 1) * 6 + layer); // continuous layer number
189 if (not m_useBackHalfCurlers and currentLayer < lastLayer) break;
190 lastLayer = currentLayer;
191
192 // track momentum at the first hit
193 if (pCDC == 0) pCDC = hit.getPOCAMomentum().R();
194
195 // drift cell
196 double innerRadius = cdcgeo.innerRadiusWireLayer()[currentLayer];
197 double outerRadius = cdcgeo.outerRadiusWireLayer()[currentLayer];
198 const ROOT::Math::XYZVector& wirePosF = cdcgeo.wireForwardPosition(wireID, CDCGeometryPar::c_Aligned);
199 double wireRadius = wirePosF.Rho();
200 int nWires = cdcgeo.nWiresInLayer(currentLayer);
201 double topHeight = outerRadius - wireRadius;
202 double bottomHeight = wireRadius - innerRadius;
203 double topHalfWidth = M_PI * outerRadius / nWires;
204 double bottomHalfWidth = M_PI * innerRadius / nWires;
205 DedxDriftCell cell(DedxPoint(-topHalfWidth, topHeight),
206 DedxPoint(topHalfWidth, topHeight),
207 DedxPoint(bottomHalfWidth, -bottomHeight),
208 DedxPoint(-bottomHalfWidth, -bottomHeight));
209
210 // length of a track within the drift cell
211 double doca = hit.getSignedDOCAXY();
212 double entAng = hit.getEntranceAngle();
213 double celldx = cell.dx(doca, entAng) / sinTheta; // length of a track in the cell
214 if (not cell.isValid()) continue;
215
216 // wire gain calibration (iwire is a continuous wire number)
217 int wire = wireID.getIWire();
218 int iwire = (superlayer == 0) ? 160 * layer + wire : m_nLayerWires[superlayer - 1] + (160 + 32 * (superlayer - 1)) * layer + wire;
219 double wiregain = isData ? m_DBWireGains->getWireGain(iwire) : 1.0;
220
221 // re-scaled (RS) doca and entAng variable: map to square cell
222 double cellHalfWidth = M_PI * wireRadius / nWires;
223 double cellHeight = topHeight + bottomHeight;
224 double cellR = 2 * cellHalfWidth / cellHeight;
225 double tana = std::max(std::min(std::tan(entAng), 1e10), -1e10); // this fixes bug in CDCDedxPIDModule near +-pi/2
226 double docaRS = doca * std::sqrt((1 + cellR * cellR * tana * tana) / (1 + tana * tana));
227 double normDocaRS = docaRS / cellHalfWidth;
228 double entAngRS = std::atan(tana / cellR);
229
230 // one and two dimensional corrections
231 double onedcor = isData ? m_DB1DCell->getMean(currentLayer, entAngRS) : 1.0;
232 double twodcor = isData ? m_DB2DCell->getMean(currentLayer, normDocaRS, entAngRS) : 1.0;
233
234 // total correction
235 double correction = runGain * cosCor * cosEdgeCor * timeGain * wiregain * twodcor * onedcor;
236
237 // calibrated ADC count
238 double adcCount = isData ? m_DBNonlADC->getCorrectedADC(hit.getADCCount(), currentLayer) : hit.getADCCount();
239 double adcCalibrated = correction != 0 ? adcCount / scale / correction : 0;
240
241 // merge dEdx measurements on single wires; take active wires only
242 if (correction != 0) dedxWires[iwire].add(hit, iwire, currentLayer, celldx, adcCalibrated);
243
244 // debug output
245 if (dedxTrack) {
246 dedxTrack->m_pCDC = pCDC;
247 const auto& pocaMom = hit.getPOCAMomentum();
248 double pocaPhi = pocaMom.Phi();
249 double pocaTheta = pocaMom.Theta();
250 double pocaZ = hit.getPOCAOnWire().Z();
251 double hitCharge = adcTranslator.getCharge(adcCount, wireID, false, pocaZ, pocaPhi);
252 double driftDRealistic = tdcTranslator.getDriftLength(hit.getTDCCount(), wireID, 0, true, pocaZ, pocaPhi, pocaTheta);
253 double driftDRealisticRes = tdcTranslator.getDriftLengthResolution(driftDRealistic, wireID, true, pocaZ, pocaPhi, pocaTheta);
254 double cellDedx = adcCalibrated / celldx;
255
256 dedxTrack->addHit(wire, iwire, currentLayer, doca, docaRS, entAng, entAngRS,
257 adcCount, hit.getADCCount(), hitCharge, celldx * sinTheta, cellDedx, cellHeight, cellHalfWidth,
258 hit.getTDCCount(), driftDRealistic, driftDRealisticRes, wiregain, twodcor, onedcor,
259 hit.getFoundByTrackFinder(), hit.getWeightPionHypo(), hit.getWeightKaonHypo(), hit.getWeightProtonHypo());
260 }
261
262 } // end of loop over hits
263
264 // merge dEdx measurements in layers
265 std::map<int, DEDX> dedxLayers;
266 for (const auto& dedxWire : dedxWires) {
267 const auto& dedx = dedxWire.second;
268 dedxLayers[dedx.cLayer].add(dedx);
269 }
270
271 // push dEdx values to a vector
272 std::vector<double> dedxValues;
273 for (const auto& dedxLayer : dedxLayers) {
274 const auto& dedx = dedxLayer.second;
275 if (dedx.dx > 0 and dedx.dE > 0) {
276 dedxValues.push_back(dedx.dE / dedx.dx);
277 // debug output
278 if (dedxTrack) dedxTrack->addDedx(dedx.nhits, dedx.cWire, dedx.cLayer, dedx.dx, dedxValues.back());
279 }
280 }
281 if (dedxValues.empty()) continue;
282
283 // sort dEdx values
284 std::sort(dedxValues.begin(), dedxValues.end());
285
286 // calculate mean
287 double mean = 0;
288 for (auto x : dedxValues) mean += x;
289 mean /= dedxValues.size();
290
291 // calculate truncated mean and error
292 int lowEdgeTrunc = int(dedxValues.size() * m_removeLowest + 0.51);
293 int highEdgeTrunc = int(dedxValues.size() * (1 - m_removeHighest) + 0.51);
294 double truncatedMean = 0;
295 double sumOfSquares = 0;
296 int numValues = 0;
297 for (int i = lowEdgeTrunc; i < highEdgeTrunc; i++) {
298 double x = dedxValues[i];
299 truncatedMean += x;
300 sumOfSquares += x * x;
301 numValues++;
302 }
303 if (numValues > 0) {
304 truncatedMean /= numValues;
305 } else {
306 truncatedMean = mean;
307 numValues = dedxValues.size();
308 }
309 double truncatedError = numValues > 1 ? std::sqrt(sumOfSquares / numValues - truncatedMean * truncatedMean) / (numValues - 1) : 0;
310
311 // apply the saturation correction only to data (the so called "hadron correction")
312 double correctedMean = isData ? m_DBHadronCor->getCorrectedMean(truncatedMean, cosTheta) : truncatedMean;
313
314 // track level MC (e.g. replacing truncated mean with a generated one)
315 if (m_trackLevel and mcParticle) {
316 double mass = mcParticle->getMass();
317 if (mass > 0) {
318 double mcMean = m_DBMeanPars->getMean(pCDC / mass);
319 double mcSigma = m_DBSigmaPars->getSigma(mcMean, numValues, cosTheta, timeReso);
320 correctedMean = gRandom->Gaus(mcMean, mcSigma);
321 while (correctedMean < 0) correctedMean = gRandom->Gaus(mcMean, mcSigma);
322 // debug output
323 if (dedxTrack) dedxTrack->m_simDedx = correctedMean;
324 }
325 }
326
327 // calculate log likelihoods
328 double cdcLogL[Const::ChargedStable::c_SetSize] = {0};
329 for (const auto& chargedStable : Const::chargedStableSet) {
330 double betagamma = pCDC / chargedStable.getMass();
331 double predictedMean = m_DBMeanPars->getMean(betagamma);
332 double predictedSigma = m_DBSigmaPars->getSigma(predictedMean, numValues, cosTheta, timeReso);
333 if (predictedSigma <= 0) B2ERROR("Predicted sigma is not positive for PDG = " << chargedStable.getPDGCode());
334 double chi = (correctedMean - predictedMean) / predictedSigma;
335 int index = chargedStable.getIndex();
336 cdcLogL[index] = -0.5 * chi * chi;
337 // debug output
338 if (dedxTrack) {
339 dedxTrack->m_predmean[index] = predictedMean;
340 dedxTrack->m_predres[index] = predictedSigma;
341 dedxTrack->m_cdcChi[index] = chi;
342 dedxTrack->m_cdcLogl[index] = cdcLogL[index];
343 }
344 }
345
346 // save log likelihoods
347 auto* likelihoods = m_likelihoods.appendNew(cdcLogL);
348 track.addRelationTo(likelihoods);
349
350 // debug output
351 if (dedxTrack) {
352 double fullLength = 0;
353 for (const auto& dedxLayer : dedxLayers) fullLength += dedxLayer.second.dx;
354 dedxTrack->m_length = fullLength;
355 dedxTrack->m_dedxAvg = mean;
356 dedxTrack->m_dedxAvgTruncatedNoSat = truncatedMean;
357 dedxTrack->m_dedxAvgTruncatedErr = truncatedError;
358 dedxTrack->m_dedxAvgTruncated = correctedMean;
359 dedxTrack->m_lNHitsUsed = numValues;
360 track.addRelationTo(dedxTrack);
361 }
362
363 } // end of loop over tracks
364
365 }
366
368} // end Belle2 namespace
369
Class to store CDC hit information needed for dedx.
Definition CDCDedxHit.h:26
bool m_useBackHalfCurlers
whether to use the back half of curlers
StoreObjPtr< EventLevelTriggerTimeInfo > m_TTDInfo
injection time info
DBObjPtr< CDCDedxRunGain > m_DBRunGain
Run gain DB object.
DBObjPtr< CDCDedxMeanPars > m_DBMeanPars
dE/dx mean parameters
DBObjPtr< CDCDedxHadronCor > m_DBHadronCor
hadron saturation parameters
double m_removeHighest
fraction of the per-layer dE/dx values with the highest dE/dx to discard (truncation)
DBObjPtr< CDCDedxCosineEdge > m_DBCosEdgeCor
cosine (polar-angle) edge correction DB object
DBObjPtr< CDCDedxADCNonLinearity > m_DBNonlADC
ADC non-linearity correction DB object
DBObjPtr< CDCDedx1DCell > m_DB1DCell
1D correction DB object
StoreArray< CDCDedxLikelihood > m_likelihoods
collection of PID likelihoods
DBObjPtr< CDCDedx2DCell > m_DB2DCell
2D correction DB object
DBObjPtr< CDCDedxCosineCor > m_DBCosineCor
Electron saturation correction DB object.
StoreArray< Track > m_tracks
collection of tracks
int m_nLayerWires[9]
cumulative number of wires up to and including each superlayer (indexed by superlayer)
DBObjPtr< CDCDedxSigmaPars > m_DBSigmaPars
dE/dx resolution parameters
DBObjPtr< CDCDedxWireGain > m_DBWireGains
Wire gain DB object.
bool m_trackLevel
whether to use track-level or hit-level MC
StoreArray< CDCDedxHit > m_hits
collection of hits
StoreArray< MCParticle > m_mcParticles
collection of MC particles
double m_removeLowest
fraction of the per-layer dE/dx values with the lowest dE/dx to discard (truncation)
StoreArray< CDCDedxTrack > m_dedxTracks
collection of debug output
DBObjPtr< CDCDedxInjectionTime > m_DBInjectTime
time gain/reso DB object
DBObjPtr< CDCDedxScaleFactor > m_DBScaleFactor
Scale factor to make electrons ~1.
Debug output for CDCDedxPID module.
static CDCGeometryPar & Instance(const CDCGeometry *=nullptr)
Static method to get a reference to the CDCGeometryPar instance.
This class simply assumes a linear translation through (0,0)
float getCharge(unsigned short adcCount, const WireID &, bool, float, float)
just multiply with the conversion factor and return.
Translator mirroring the realistic Digitization.
double getDriftLength(unsigned short tdcCount, const WireID &wireID=WireID(), double timeOfFlightEstimator=0, bool leftRight=false, double z=0, double alpha=0, double theta=static_cast< double >(TMath::Pi()/2.), unsigned short adcCount=0) override
Get Drift length.
double getDriftLengthResolution(double driftLength, const WireID &wireID=WireID(), bool leftRight=false, double z=0, double alpha=0, double=static_cast< double >(TMath::Pi()/2.)) override
Get position resolution^2 corresponding to the drift length from getDriftLength of this class.
static const unsigned int c_SetSize
Number of elements (for use in array bounds etc.)
Definition Const.h:615
static const ParticleSet chargedStableSet
set of charged stable particles
Definition Const.h:618
static const ChargedStable pion
charged pion particle
Definition Const.h:661
A class to hold the geometry of a cell.
Definition LineHelper.h:186
double dx(const DedxPoint &poca, double entAng)
Calculate the path length through this cell for a track with a given DedxPoint Of Closest Approach (p...
Definition LineHelper.h:203
bool isValid()
Check if this is a valid calculation (number of intersections = 2)
Definition LineHelper.h:199
A collection of classes that are useful for making a simple path length correction to the dE/dx measu...
Definition LineHelper.h:29
bool isMC() const
Do we have generated, not real data?
static Environment & Instance()
Static method to get a reference to the Environment instance.
A Class to store the Monte Carlo particle information.
Definition MCParticle.h:32
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
virtual void initialize() override
Initialize the module.
virtual void event() override
This method is called for each event.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.
STL namespace.