Belle II Software development
ECLClusterPSD.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9/* Own header. */
10#include <ecl/modules/eclClusterPSD/ECLClusterPSD.h>
11
12/* ECL headers. */
13#include <ecl/dataobjects/ECLCalDigit.h>
14#include <ecl/dataobjects/ECLShower.h>
15#include <ecl/geometry/ECLGeometryPar.h>
16
17/* Basf2 headers */
18#include <framework/logging/Logger.h>
19#include <framework/geometry/B2Vector3.h>
20#include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
21#include <mva/interface/Expert.h>
22#include <mva/interface/Interface.h>
23#include <mva/interface/Weightfile.h>
24
25/* Boost headers. */
26#include <boost/algorithm/string/predicate.hpp>
27
28/* C++ headers. */
29#include <cmath>
30
31using namespace Belle2;
32
33//-----------------------------------------------------------------
34// Register the Modules
35//-----------------------------------------------------------------
36REG_MODULE(ECLClusterPSD);
37//-----------------------------------------------------------------
38// Implementation
39//-----------------------------------------------------------------
40
41// constructor
// NOTE(review): the embedded listing numbers jump 41 -> 43 and 45 -> 47, so the constructor
// signature (listing line 42) and listing line 46 are absent from this extract.
// The visible body sets the module description and registers three steering parameters,
// each bound to a member variable and given a default value.
43{
44 // Set module properties
45 setDescription("Module uses offline two component fit results to compute pulse shape discrimation variables for particle identification.");
// Hadron-component energy threshold, bound to m_CrystalHadronEnergyThreshold; default 0.003 (description states GeV).
47 addParam("CrystalHadronEnergyThreshold", m_CrystalHadronEnergyThreshold,
48 "Hadron component energy threshold to identify as hadron digit.(GeV)", 0.003);
// Hadron-component intensity threshold, bound to m_CrystalHadronIntensityThreshold; default 0.005.
49 addParam("CrystalHadronIntensityThreshold", m_CrystalHadronIntensityThreshold,
50 "Hadron component intensity threshold to identify as hadron digit.", 0.005);
// MVA identifier string, bound to m_MVAidentifier; default "eclClusterPSD_MVA".
51 addParam("MVAidentifier", m_MVAidentifier, "MVA database identifier.", std::string{"eclClusterPSD_MVA"});
52}
53
54// destructor
58
59// initialize MVA weightFile
// Prepares the weight-file representation for the given identifier.
// When the identifier does NOT end in ".root" or ".xml", a newly allocated object is stored in
// weightFileRepresentation (a unique_ptr to DBObjPtr<DatabaseRepresentationOfWeightfile>).
// For identifiers ending in ".root" or ".xml" the visible code leaves weightFileRepresentation untouched.
// NOTE(review): listing lines 65 and 67 are absent from this extract (numbering jumps 64 -> 66 -> 68),
// so the operand of the `new` expression and anything after the if-block are not visible here.
60void ECLClusterPSDModule::initializeMVAweightFile(const std::string& identifier,
61 std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile>>& weightFileRepresentation)
62{
63 if (not(boost::ends_with(identifier, ".root") or boost::ends_with(identifier, ".xml"))) {
64 weightFileRepresentation = std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile>>(new
66 }
68}
69
70// initialize
79
80// initialize MVA
81void ECLClusterPSDModule::initializeMVA(const std::string& identifier,
82 std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile>>& weightFileRepresentation, std::unique_ptr<MVA::Expert>& expert)
83{
84 MVA::Weightfile weightfile;
85 //Load MVA weight file
86 if (weightFileRepresentation) {
87
88 if (weightFileRepresentation->hasChanged()) {
89 std::stringstream ss((*weightFileRepresentation)->m_data);
90 weightfile = MVA::Weightfile::loadFromStream(ss);
91 } else
92 return;
93 } else {
94 weightfile = MVA::Weightfile::loadFromFile(identifier);
95 }
96
97
98 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
99 MVA::GeneralOptions general_options;
100 weightfile.getOptions(general_options);
101
102 //Check number of variables in weight file
103 if (m_numMVAvariables != general_options.m_variables.size())
104 B2FATAL("Expecting " << m_numMVAvariables << " variables, found " << general_options.m_variables.size());
105
106 expert = supported_interfaces[general_options.m_method]->getExpert();
107 expert->load(weightfile);
108
109 //create new dataset
110 if (weightFileRepresentation == m_weightfile_representation) {
111 std::vector<float> dummy(general_options.m_variables.size(), 0);
112 m_dataset = std::unique_ptr<MVA::SingleDataset>(new MVA::SingleDataset(general_options, dummy, 0));
113 }
114}
115
116// begin run
121
122// evaluates mva
// Builds the MVA input vector from the calibrated digits related to `cluster` and returns the
// first element of the expert's output.
// NOTE(review): the embedded listing numbers jump 122 -> 124 and 126 -> 128, so the function
// signature (listing line 123) and the declaration of `geometry` (listing line 127) are absent
// from this extract; `geometry` is used below to look up crystal positions.
124{
125
126 //geometry for cell id position
128
129 auto relatedDigits = cluster->getRelationsTo<ECLCalDigit>();
130
131 //EnergyToSort vector is used for sorting digits by offline two component energy
// Each entry is (two-component total energy, index into relatedDigits).
132 std::vector<std::tuple<double, unsigned int>> EnergyToSort;
133
134 for (unsigned int iRel = 0; iRel < relatedDigits.size(); iRel++) {
135
136 const auto caldigit = relatedDigits.object(iRel);
137
138 //exclude digits without waveforms
139 const double digitChi2 = caldigit->getTwoComponentChi2();
140 if (digitChi2 < 0) continue;
141
142 ECLDsp::TwoComponentFitType digitFitType1 = caldigit->getTwoComponentFitType();
143
144 //exclude digits with poor chi2
145 if (digitFitType1 == ECLDsp::poorChi2) continue;
146
147 //exclude digits with diode-crossing fits
148 if (digitFitType1 == ECLDsp::photonDiodeCrossing) continue;
149
150 EnergyToSort.emplace_back(caldigit->getTwoComponentTotalEnergy(), iRel);
151
152 }
153
154 //sorting by energy
// std::greater on the tuples gives descending energy; equal energies are ordered by descending index.
155 std::sort(EnergyToSort.begin(), EnergyToSort.end(), std::greater<>());
156
157 //get cluster position information
158 const double showerR = cluster->getR();
159 const double showerTheta = cluster->getTheta();
160 const double showerPhi = cluster->getPhi();
161
162 B2Vector3D showerPosition;
163 showerPosition.SetMagThetaPhi(showerR, showerTheta, showerPhi);
164
// Running write position in the dataset input; every accepted digit fills eight consecutive slots.
165 size_t input_index{0};
166 auto& input = m_dataset->m_input;
167
// Use at most maxdigits of the highest-energy accepted digits.
168 for (unsigned int digit = 0; digit < maxdigits; ++digit) {
169
170 if (digit >= EnergyToSort.size()) break;
171
172 const auto [digitEnergy, next] = EnergyToSort[digit];
173
174 const auto caldigit = relatedDigits.object(next);
175 const double digitHadronEnergy = caldigit->getTwoComponentHadronEnergy();
176 const double digitOnlineEnergy = caldigit->getEnergy();
177 const double digitWeight = relatedDigits.weight(next);
178 ECLDsp::TwoComponentFitType digitFitType1 = caldigit->getTwoComponentFitType();
179 const int digitFitType = digitFitType1;
180 const int cellId = caldigit->getCellId();
// GetCrystalPos is called with cellId - 1 (presumably a zero-based crystal index - TODO confirm).
181 B2Vector3D calDigitPosition = geometry->GetCrystalPos(cellId - 1);
// Vector from the crystal position to the shower position.
182 ROOT::Math::XYZVector tempP = showerPosition - calDigitPosition;
183 const double Rval = tempP.R();
184 const double theVal = tempP.Z() / tempP.R();
185 const double phiVal = cos(tempP.Phi());
186
// Slot order per digit: theVal, phiVal, Rval, online energy, two-component energy,
// hadron energy / two-component energy, fit type, relation weight.
187 input[input_index++] = theVal;
188 input[input_index++] = phiVal;
189 input[input_index++] = Rval;
190 input[input_index++] = digitOnlineEnergy;
191 input[input_index++] = digitEnergy;
192 input[input_index++] = (digitHadronEnergy / digitEnergy);
193 input[input_index++] = digitFitType;
194 input[input_index++] = digitWeight;
195
196 }
197
198 //fill remainder with defaults
// Slot 6 of each group of eight is the fit type and defaults to -1; every other slot defaults to 0.
// input_index is unsigned, so (input_index - 6) wraps for indices below 6; the test still selects
// index % 8 == 6 because the wrap modulus of size_t is a multiple of 8.
199 while (input_index < input.size()) {
200 if (((input_index - 6) % 8) != 0) {
201 input[input_index++] = 0.0;
202 } else {
203 input[input_index++] = -1.0; //Fit Type
204 }
205 }
206
207 //compute mva from input variables
208 const double MVAout = m_expert->apply(*m_dataset)[0];
209
210 return MVAout;
211}
212
213
// For every shower in m_eclShowers: accumulate two-component energies over the related digits,
// count hadron-like digits (weighted by the relation weight) and store the hadron intensity,
// the MVA output and the hadron-digit count on the shower.
// NOTE(review): the embedded listing numbers jump 213 -> 215 and 267 -> 269, so the function
// signature (listing line 214) and listing line 268 are absent from this extract.
215{
216
217 for (auto& shower : m_eclShowers) {
218
219
220 auto relatedDigits = shower.getRelationsTo<ECLCalDigit>();
221
222 double cluster2CTotalEnergy = 0;
223 double cluster2CHadronEnergy = 0;
// Both counters below are sums of relation weights, hence floating point.
224 double numberofHadronDigits = 0;
225 double nWaveforminCluster = 0;
226
227 for (unsigned int iRel = 0; iRel < relatedDigits.size(); iRel++) {
228
229 const auto weight = relatedDigits.weight(iRel);
230
231 const auto caldigit = relatedDigits.object(iRel);
232 const double digit2CChi2 = caldigit->getTwoComponentChi2();
233
234 if (digit2CChi2 < 0) continue; //only digits with waveforms
235
236 ECLDsp::TwoComponentFitType digitFitType1 = caldigit->getTwoComponentFitType();
237
238 //exclude digits with poor chi2
239 if (digitFitType1 == ECLDsp::poorChi2) continue;
240
241 //exclude digits with diode-crossing fits
242 if (digitFitType1 == ECLDsp::photonDiodeCrossing) continue;
243
244 const double digit2CTotalEnergy = caldigit->getTwoComponentTotalEnergy();
245 const double digit2CHadronComponentEnergy = caldigit->getTwoComponentHadronEnergy();
246
247 cluster2CTotalEnergy += digit2CTotalEnergy;
248 cluster2CHadronEnergy += digit2CHadronComponentEnergy;
249
// Below a total energy of 0.6 the absolute hadron energy is compared with the energy threshold;
// otherwise the hadron fraction is compared with the intensity threshold.
// NOTE(review): 0.6 is presumably in GeV like the energy threshold - confirm.
250 if (digit2CTotalEnergy < 0.6) {
251 if (digit2CHadronComponentEnergy > m_CrystalHadronEnergyThreshold) numberofHadronDigits += weight;
252 } else {
253 const double digitHadronComponentIntensity = digit2CHadronComponentEnergy / digit2CTotalEnergy;
254 if (digitHadronComponentIntensity > m_CrystalHadronIntensityThreshold) numberofHadronDigits += weight;
255 }
256
257 nWaveforminCluster += weight;
258
259 }
260
261 if (nWaveforminCluster > 0) {
// The hadron intensity is only written here when the summed total energy is non-zero.
262 if (cluster2CTotalEnergy != 0) shower.setShowerHadronIntensity(cluster2CHadronEnergy / cluster2CTotalEnergy);
263 //evaluates mva classifier only if waveforms are available in the cluster
264 const double mvaout = evaluateMVA(&shower);
265 shower.setPulseShapeDiscriminationMVA(mvaout);
266
267 shower.setNumberOfHadronDigits(numberofHadronDigits);
269
270 } else {
// No accepted waveform digits: store the defaults (intensity 0, MVA output 0.5, zero hadron digits).
271 shower.setShowerHadronIntensity(0);
272 shower.setPulseShapeDiscriminationMVA(0.5);
273 shower.setNumberOfHadronDigits(0);
274 }
275 }
276}
277
278// end run
282
283// terminate
DataType Z() const
access variable Z (= .at(2) without boundary check)
Definition B2Vector3.h:435
void SetMagThetaPhi(DataType mag, DataType theta, DataType phi)
setter with mag, theta, phi
Definition B2Vector3.h:259
Class for accessing objects in the database.
Definition DBObjPtr.h:21
Class to store calibrated ECLDigits: ECLCalDigits.
Definition ECLCalDigit.h:23
virtual const char * eclShowerArrayName() const
ECLShowers array name.
StoreArray< ECLShower > m_eclShowers
ECLShower's.
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
virtual void initialize() override
Initialize variables.
virtual void event() override
event per event.
virtual void endRun() override
end run.
virtual void terminate() override
terminate.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
const unsigned int maxdigits
Max number of digits mva can include.
const unsigned int m_numMVAvariables
number of variables expected in the MVA weightfile
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the MVA weightfile.
virtual void beginRun() override
begin run.
void initializeMVAweightFile(const std::string &identifier, std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > &weightFileRepresentation)
initialize MVA weight file from DB
double m_CrystalHadronEnergyThreshold
hadron component energy threshold to classify as hadron.
std::string m_MVAidentifier
MVA - weight-file.
virtual const char * eclCalDigitArrayName() const
ECLCalDigits array name.
double evaluateMVA(const ECLShower *cluster)
Evaluates mva.
StoreArray< ECLCalDigit > m_eclCalDigits
ECLCalDigit's.
void initializeMVA(const std::string &identifier, std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > &weightFileRepresentation, std::unique_ptr< MVA::Expert > &expert)
Load MVA weight file and set pointer of expert.
double m_CrystalHadronIntensityThreshold
hadron component intensity threshold to classify as hadron.
TwoComponentFitType
Offline two component fit type.
Definition ECLDsp.h:29
@ poorChi2
All offline fit attempts were greater than chi2 threshold.
Definition ECLDsp.h:30
@ photonDiodeCrossing
photon + diode template fit
Definition ECLDsp.h:33
Class to store ECL Showers.
Definition ECLShower.h:30
@ c_hasPulseShapeDiscrimination
bit 2: Shower has pulse shape discrimination variables.
Definition ECLShower.h:61
The Class for ECL Geometry Parameters.
static ECLGeometryPar * Instance()
Static method to get a reference to the ECLGeometryPar instance.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedInterfaces.
Definition Interface.cc:46
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition Interface.h:53
General options which are shared by all MVA trainings.
Definition Options.h:62
Wraps the data of a single event into a Dataset.
Definition Dataset.h:135
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition Weightfile.cc:67
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
B2Vector3< double > B2Vector3D
typedef for common usage with double
Definition B2Vector3.h:516
Common code concerning the geometry representation of the detector.
Definition CreatorBase.h:25
Abstract base class for different kinds of events.