Belle II Software development
ECLClusterPSD.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9/* Own header. */
10#include <ecl/modules/eclClusterPSD/ECLClusterPSD.h>
11
12/* ECL headers. */
13#include <ecl/dataobjects/ECLCalDigit.h>
14#include <ecl/dataobjects/ECLShower.h>
15#include <ecl/geometry/ECLGeometryPar.h>
16
17/* Basf2 headers */
18#include <framework/logging/Logger.h>
19#include <framework/geometry/B2Vector3.h>
20#include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
21#include <mva/interface/Expert.h>
22#include <mva/interface/Interface.h>
23#include <mva/interface/Weightfile.h>
24
25/* C++ headers. */
26#include <cmath>
27
28using namespace Belle2;
29
30//-----------------------------------------------------------------
31// Register the Modules
32//-----------------------------------------------------------------
// Registers ECLClusterPSDModule with the framework under the name "ECLClusterPSD"
// (REG_MODULE takes the module name without the 'Module' suffix).
33REG_MODULE(ECLClusterPSD);
34//-----------------------------------------------------------------
35// Implementation
36//-----------------------------------------------------------------
37
38// constructor
// Module constructor: sets the module description and registers the steering parameters.
// NOTE(review): "discrimation" in the description string is a typo for "discrimination".
40{
41 // Set module properties
42 setDescription("Module uses offline two component fit results to compute pulse shape discrimation variables for particle identification.");
// Hadron-component energy cut (GeV, default 0.003). In event() it is applied to digits whose
// two-component total energy is below 0.6.
44 addParam("CrystalHadronEnergyThreshold", m_CrystalHadronEnergyThreshold,
45 "Hadron component energy threshold to identify as hadron digit.(GeV)", 0.003);
// Cut on hadron energy / total energy (default 0.005). In event() it is applied to digits whose
// two-component total energy is 0.6 or above.
46 addParam("CrystalHadronIntensityThreshold", m_CrystalHadronIntensityThreshold,
47 "Hadron component intensity threshold to identify as hadron digit.", 0.005);
// Identifiers ending in ".root" or ".xml" are loaded as local files; anything else is
// treated as a database identifier (see initializeMVAweightFile / initializeMVA).
48 addParam("MVAidentifier", m_MVAidentifier, "MVA database identifier.", std::string{"eclClusterPSD_MVA"});
49}
50
51// destructor
55
56// initialize MVA weightFile
// Set up the database representation of the MVA weightfile.
// If identifier does not end in ".root" or ".xml", weightFileRepresentation is assigned a
// newly allocated DBObjPtr (presumably constructed from identifier -- confirm). Otherwise the
// pointer is left unchanged; when it is null, initializeMVA() loads identifier as a local file.
// @param identifier               MVA weightfile identifier (database name or file name)
// @param weightFileRepresentation output: owning pointer to the database object; assigned only
//                                 for database identifiers
57void ECLClusterPSDModule::initializeMVAweightFile(const std::string& identifier,
58 std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile>>& weightFileRepresentation)
59{
// Anything that is not a .root/.xml file name is looked up in the database.
60 if (not(identifier.ends_with(".root") or identifier.ends_with(".xml"))) {
61 weightFileRepresentation = std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile>>(new
63 }
65}
66
67// initialize
76
77// initialize MVA
78void ECLClusterPSDModule::initializeMVA(const std::string& identifier,
79 std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile>>& weightFileRepresentation, std::unique_ptr<MVA::Expert>& expert)
80{
81 MVA::Weightfile weightfile;
82 //Load MVA weight file
83 if (weightFileRepresentation) {
84
85 if (weightFileRepresentation->hasChanged()) {
86 std::stringstream ss((*weightFileRepresentation)->m_data);
87 weightfile = MVA::Weightfile::loadFromStream(ss);
88 } else
89 return;
90 } else {
91 weightfile = MVA::Weightfile::loadFromFile(identifier);
92 }
93
94
95 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
96 MVA::GeneralOptions general_options;
97 weightfile.getOptions(general_options);
98
99 //Check number of variables in weight file
100 if (m_numMVAvariables != general_options.m_variables.size())
101 B2FATAL("Expecting " << m_numMVAvariables << " variables, found " << general_options.m_variables.size());
102
103 expert = supported_interfaces[general_options.m_method]->getExpert();
104 expert->load(weightfile);
105
106 //create new dataset
107 if (weightFileRepresentation == m_weightfile_representation) {
108 std::vector<float> dummy(general_options.m_variables.size(), 0);
109 m_dataset = std::unique_ptr<MVA::SingleDataset>(new MVA::SingleDataset(general_options, dummy, 0));
110 }
111}
112
113// begin run
118
119// evaluates mva
// Evaluate the pulse-shape-discrimination MVA for one shower (cluster).
// Related digits with a two-component fit (chi2 >= 0) whose fit type is neither poorChi2 nor
// photonDiodeCrossing are sorted by two-component total energy, highest first. For up to
// maxdigits of them, 8 input variables per digit are written into m_dataset->m_input; all
// remaining slots are filled with defaults. Returns the first value of m_expert->apply().
// NOTE(review): assumes initializeMVA() has already created m_dataset and m_expert. m_input has
// m_numMVAvariables entries, and the per-digit loop writes 8 entries per digit without a bounds
// check -- confirm that 8 * maxdigits <= m_numMVAvariables.
121{
122
123 //geometry for cell id position
125
126 auto relatedDigits = cluster->getRelationsTo<ECLCalDigit>();
127
128 //EnergyToSort vector is used for sorting digits by offline two component energy
129 std::vector<std::tuple<double, unsigned int>> EnergyToSort;
130
131 for (unsigned int iRel = 0; iRel < relatedDigits.size(); iRel++) {
132
133 const auto caldigit = relatedDigits.object(iRel);
134
135 //exclude digits without waveforms
136 const double digitChi2 = caldigit->getTwoComponentChi2();
137 if (digitChi2 < 0) continue;
138
139 ECLDsp::TwoComponentFitType digitFitType1 = caldigit->getTwoComponentFitType();
140
141 //exclude digits digits with poor chi2
142 if (digitFitType1 == ECLDsp::poorChi2) continue;
143
144 //exclude digits with diode-crossing fits
145 if (digitFitType1 == ECLDsp::photonDiodeCrossing) continue;
146
147 EnergyToSort.emplace_back(caldigit->getTwoComponentTotalEnergy(), iRel);
148
149 }
150
151 //sorting by energy
// std::greater<> on (energy, index) tuples: descending energy, ties broken by larger relation index.
152 std::sort(EnergyToSort.begin(), EnergyToSort.end(), std::greater<>());
153
154 //get cluster position information
155 const double showerR = cluster->getR();
156 const double showerTheta = cluster->getTheta();
157 const double showerPhi = cluster->getPhi();
158
159 B2Vector3D showerPosition;
160 showerPosition.SetMagThetaPhi(showerR, showerTheta, showerPhi);
161
162 size_t input_index{0};
163 auto& input = m_dataset->m_input;
164
// Per-digit input layout (offset within each group of 8):
// 0: cos(polar angle) of (shower position - crystal position), 1: cos(azimuth) of that vector,
// 2: its length, 3: online energy, 4: two-component total energy, 5: hadron / total energy,
// 6: two-component fit type, 7: relation weight of the digit to the shower.
165 for (unsigned int digit = 0; digit < maxdigits; ++digit) {
166
167 if (digit >= EnergyToSort.size()) break;
168
169 const auto [digitEnergy, next] = EnergyToSort[digit];
170
171 const auto caldigit = relatedDigits.object(next);
172 const double digitHadronEnergy = caldigit->getTwoComponentHadronEnergy();
173 const double digitOnlineEnergy = caldigit->getEnergy();
174 const double digitWeight = relatedDigits.weight(next);
175 ECLDsp::TwoComponentFitType digitFitType1 = caldigit->getTwoComponentFitType();
176 const int digitFitType = digitFitType1;
177 const int cellId = caldigit->getCellId();
// GetCrystalPos is indexed with cellId - 1, i.e. cell IDs are treated as 1-based here.
178 B2Vector3D calDigitPosition = geometry->GetCrystalPos(cellId - 1);
179 ROOT::Math::XYZVector tempP = showerPosition - calDigitPosition;
180 const double Rval = tempP.R();
// NOTE(review): theVal is NaN if the shower position coincides with the crystal position
// (R == 0), and the hadron fraction below divides by digitEnergy without a zero check.
181 const double theVal = tempP.Z() / tempP.R();
182 const double phiVal = cos(tempP.Phi());
183
184 input[input_index++] = theVal;
185 input[input_index++] = phiVal;
186 input[input_index++] = Rval;
187 input[input_index++] = digitOnlineEnergy;
188 input[input_index++] = digitEnergy;
189 input[input_index++] = (digitHadronEnergy / digitEnergy);
190 input[input_index++] = digitFitType;
191 input[input_index++] = digitWeight;
192
193 }
194
195 //fill remainder with defaults
// Unfilled slots: offset 6 (fit type) gets -1, every other offset gets 0. input_index - 6 wraps
// for input_index < 6, but because the unsigned modulus is a power of two the result modulo 8
// still equals (input_index + 2) % 8, so the offset test remains correct.
196 while (input_index < input.size()) {
197 if (((input_index - 6) % 8) != 0) {
198 input[input_index++] = 0.0;
199 } else {
200 input[input_index++] = -1.0; //Fit Type
201 }
202 }
203
204 //compute mva from input variables
205 const double MVAout = m_expert->apply(*m_dataset)[0];
206
207 return MVAout;
208}
209
210
// Compute the pulse-shape-discrimination quantities for every ECLShower in the event.
// Only related digits with a two-component fit (chi2 >= 0) and a fit type other than poorChi2 or
// photonDiodeCrossing contribute. A shower with at least one such digit (positive summed weight)
// gets three values stored:
//   - the hadron intensity (summed hadron energy / summed total energy);
//   - the MVA output from evaluateMVA;
//   - the weighted number of hadron digits.
// Any other shower gets intensity 0, MVA 0.5 and 0 hadron digits.
212{
213
214 for (auto& shower : m_eclShowers) {
215
216
217 auto relatedDigits = shower.getRelationsTo<ECLCalDigit>();
218
219 double cluster2CTotalEnergy = 0;
220 double cluster2CHadronEnergy = 0;
// Both counters below accumulate relation weights, so they are doubles rather than integer counts.
221 double numberofHadronDigits = 0;
222 double nWaveforminCluster = 0;
223
224 for (unsigned int iRel = 0; iRel < relatedDigits.size(); iRel++) {
225
226 const auto weight = relatedDigits.weight(iRel);
227
228 const auto caldigit = relatedDigits.object(iRel);
229 const double digit2CChi2 = caldigit->getTwoComponentChi2();
230
231 if (digit2CChi2 < 0) continue; //only digits with waveforms
232
233 ECLDsp::TwoComponentFitType digitFitType1 = caldigit->getTwoComponentFitType();
234
235 //exclude digits digits with poor chi2
236 if (digitFitType1 == ECLDsp::poorChi2) continue;
237
238 //exclude digits with diode-crossing fits
239 if (digitFitType1 == ECLDsp::photonDiodeCrossing) continue;
240
241 const double digit2CTotalEnergy = caldigit->getTwoComponentTotalEnergy();
242 const double digit2CHadronComponentEnergy = caldigit->getTwoComponentHadronEnergy();
243
244 cluster2CTotalEnergy += digit2CTotalEnergy;
245 cluster2CHadronEnergy += digit2CHadronComponentEnergy;
246
// Digits with total energy below 0.6 (presumably GeV, like the CrystalHadronEnergyThreshold
// parameter -- confirm) are classified by absolute hadron energy; higher-energy digits are
// classified by the hadron / total energy ratio.
247 if (digit2CTotalEnergy < 0.6) {
248 if (digit2CHadronComponentEnergy > m_CrystalHadronEnergyThreshold) numberofHadronDigits += weight;
249 } else {
250 const double digitHadronComponentIntensity = digit2CHadronComponentEnergy / digit2CTotalEnergy;
251 if (digitHadronComponentIntensity > m_CrystalHadronIntensityThreshold) numberofHadronDigits += weight;
252 }
253
254 nWaveforminCluster += weight;
255
256 }
257
258 if (nWaveforminCluster > 0) {
// NOTE(review): when the summed total energy is exactly 0, the hadron intensity is not set here.
259 if (cluster2CTotalEnergy != 0) shower.setShowerHadronIntensity(cluster2CHadronEnergy / cluster2CTotalEnergy);
260 //evaluates mva classifier only if waveforms are available in the cluster
261 const double mvaout = evaluateMVA(&shower);
262 shower.setPulseShapeDiscriminationMVA(mvaout);
263
264 shower.setNumberOfHadronDigits(numberofHadronDigits);
266
267 } else {
268 shower.setShowerHadronIntensity(0);
269 shower.setPulseShapeDiscriminationMVA(0.5);
270 shower.setNumberOfHadronDigits(0);
271 }
272 }
273}
274
275// end run
279
280// terminate
DataType Z() const
access variable Z (= .at(2) without boundary check)
Definition B2Vector3.h:435
void SetMagThetaPhi(DataType mag, DataType theta, DataType phi)
setter with mag, theta, phi
Definition B2Vector3.h:259
Class for accessing objects in the database.
Definition DBObjPtr.h:21
Class to store calibrated ECLDigits: ECLCalDigits.
Definition ECLCalDigit.h:23
virtual const char * eclShowerArrayName() const
ECLShowers array name.
StoreArray< ECLShower > m_eclShowers
ECLShowers.
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
virtual void initialize() override
Initialize variables.
virtual void event() override
Called once per event.
virtual void endRun() override
end run.
virtual void terminate() override
terminate.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
const unsigned int maxdigits
Max number of digits mva can include.
const unsigned int m_numMVAvariables
number of variables expected in the MVA weightfile
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the MVA weightfile.
virtual void beginRun() override
begin run.
void initializeMVAweightFile(const std::string &identifier, std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > &weightFileRepresentation)
initialize MVA weight file from DB
double m_CrystalHadronEnergyThreshold
hadron component energy threshold to classify as hadron.
std::string m_MVAidentifier
MVA weightfile identifier.
virtual const char * eclCalDigitArrayName() const
ECLCalDigits array name.
double evaluateMVA(const ECLShower *cluster)
Evaluates mva.
StoreArray< ECLCalDigit > m_eclCalDigits
ECLCalDigits.
void initializeMVA(const std::string &identifier, std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > &weightFileRepresentation, std::unique_ptr< MVA::Expert > &expert)
Load MVA weight file and set pointer of expert.
double m_CrystalHadronIntensityThreshold
hadron component intensity threshold to classify as hadron.
TwoComponentFitType
Offline two component fit type.
Definition ECLDsp.h:29
@ poorChi2
All offline fit attempts were greater than chi2 threshold.
Definition ECLDsp.h:30
@ photonDiodeCrossing
photon + diode template fit
Definition ECLDsp.h:33
Class to store ECL Showers.
Definition ECLShower.h:30
@ c_hasPulseShapeDiscrimination
bit 2: Shower has pulse shape discrimination variables.
Definition ECLShower.h:61
The Class for ECL Geometry Parameters.
static ECLGeometryPar * Instance()
Static method to get a reference to the ECLGeometryPar instance.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupported...
Definition Interface.cc:46
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition Interface.h:53
General options which are shared by all MVA trainings.
Definition Options.h:62
Wraps the data of a single event into a Dataset.
Definition Dataset.h:135
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
B2Vector3< double > B2Vector3D
typedef for common usage with double
Definition B2Vector3.h:516
Common code concerning the geometry representation of the detector.
Definition CreatorBase.h:25
Abstract base class for different kinds of events.