Belle II Software development
config.py
1#!/usr/bin/env python3
2
3
10
11# @cond SUPPRESS_DOXYGEN
12
13"""
14 Configuration classes
15
16 The classes defined here are used to uniquely define a FEI training.
17 Meaning:
18 - The global configuration like database prefix, cache mode, monitoring, ... (FeiConfiguration)
19 - The reconstructed Particles (Particle)
20 - The reconstructed Channels of each particle (DecayChannel)
21 - The MVA configuration for each channel (MVAConfiguration)
22 - The Cut definitions of each channel (PreCutConfiguration)
23 - The Cut definitions of each particle (PostCutConfiguration)
24"""
25
26import collections
27import copy
28import re
29import itertools
30import typing
31import basf2
32
# Define classes at top level to make them picklable
# Each class is created via namedtuple, which behaves like a struct in C
35
# Global FEI settings, modelled as an immutable record in which every field has a default.
FeiConfiguration = collections.namedtuple(
    'FeiConfiguration',
    ['prefix', 'cache', 'monitor', 'legacy', 'externTeacher', 'training', 'roundMode', 'monitoring_path'],
)
# Defaults are listed in field order: prefix, cache, monitor, legacy, externTeacher,
# training, roundMode, monitoring_path.
FeiConfiguration.__new__.__defaults__ = (
    'FEI_TEST',
    None,
    True,
    None,
    'basf2_mva_teacher',
    False,
    0,
    '',
)
FeiConfiguration.__doc__ = "Fei Global Configuration class"
FeiConfiguration.prefix.__doc__ = "The database prefix used for all weight files"
FeiConfiguration.cache.__doc__ = (
    "The stage which is passed as input, it is assumed that all previous stages do not have "
    "to be reconstructed again. Can be either a number or a filename containing a pickled "
    "number or None in this case the environment variable FEI_STAGE is used."
)
FeiConfiguration.monitor.__doc__ = (
    "Determines the level of monitoring histograms to create. Set to False to disable monitoring. "
    "Set to 'simple' to enable lightweight histograms. Any other value will enable full monitoring "
    "histograms."
)
FeiConfiguration.legacy.__doc__ = (
    "Pass the summary file of a legacy FEI training, and the algorithm will be able to apply "
    "this training."
)
FeiConfiguration.externTeacher.__doc__ = "Teacher command e.g. basf2_mva_teacher, b2mva-kekcc-cluster-teacher"
FeiConfiguration.training.__doc__ = "If you train the FEI set this to True, otherwise to False"
FeiConfiguration.roundMode.__doc__ = "Round mode for the training. 0 default, 1 resuming, 2 finishing, 3 retraining."
FeiConfiguration.monitoring_path.__doc__ = "Path where monitoring histograms are stored."
57
58
# Per-channel multivariate-classifier settings, modelled as an immutable record.
MVAConfiguration = collections.namedtuple(
    'MVAConfiguration',
    ['method', 'config', 'variables', 'target', 'sPlotVariable', 'spectators'],
)
# NOTE: the {} default for spectators is created once here, so every instance built
# without an explicit spectators value refers to the same dict object.
MVAConfiguration.__new__.__defaults__ = (
    'FastBDT',
    '--nTrees 400 --nCutLevels 10 --nLevels 3 --shrinkage 0.1 --randRatio 0.5',
    None,
    'isSignal',
    None,
    {},
)
MVAConfiguration.__doc__ = "Multivariate analysis configuration class."
MVAConfiguration.method.__doc__ = "Method used by MVAInterface."
MVAConfiguration.config.__doc__ = "Method specific configuration string passed to basf2_mva_teacher"
MVAConfiguration.variables.__doc__ = (
    "List of variables from the VariableManager. {} is expanded to one variable per daughter particle."
)
MVAConfiguration.target.__doc__ = "Target variable from the VariableManager."
MVAConfiguration.sPlotVariable.__doc__ = "Discriminating variable used by sPlot to do data-driven training."
MVAConfiguration.spectators.__doc__ = "Dictionary of spectator variables with their ranges from the VariableManager."
71
72
# Cuts applied to the candidates of a channel before the classifier is trained,
# modelled as an immutable record in which every field has a default.
PreCutConfiguration = collections.namedtuple(
    'PreCutConfiguration',
    'userCut, vertexCut, noBackgroundSampling,'
    'bestCandidateVariable, bestCandidateCut, bestCandidateMode, noSignalSampling, bkgSamplingFactor')
# Defaults are listed in field order.
PreCutConfiguration.__new__.__defaults__ = ('', -2, False, None, 0, 'lowest', False, 1.0)
PreCutConfiguration.__doc__ = "PreCut configuration class. These cuts are employed before training the mva classifier."
PreCutConfiguration.userCut.__doc__ = (
    "The user cut is passed directly to the ParticleCombiner."
    " Particles which do not pass this cut are immediately discarded."
)
PreCutConfiguration.vertexCut.__doc__ = "The vertex cut is passed as confidence level to the VertexFitter."
PreCutConfiguration.noBackgroundSampling.__doc__ = (
    "For very pure channels, the background sampling factor is too high"
    " and the MVA can't be trained. This disables background sampling."
)
PreCutConfiguration.bestCandidateVariable.__doc__ = "Variable from the VariableManager which is used to rank all candidates."
PreCutConfiguration.bestCandidateCut.__doc__ = "Number of best-candidates to keep after the best-candidate ranking."
PreCutConfiguration.bestCandidateMode.__doc__ = "Either lowest or highest."
PreCutConfiguration.noSignalSampling.__doc__ = (
    "For channels with unknown br. frac., the signal sampling factor can be"
    " overestimated and you lose signal samples in the training."
    " This disables signal sampling."
)
PreCutConfiguration.bkgSamplingFactor.__doc__ = "Add additional multiplicative bkg. sampling factor, less than 1.0 to reduce."
90
# Cuts applied to the candidates of a particle after the classifier has been trained.
PostCutConfiguration = collections.namedtuple(
    'PostCutConfiguration',
    ['value', 'bestCandidateCut'],
)
# Defaults in field order: value, bestCandidateCut.
PostCutConfiguration.__new__.__defaults__ = (0.0, 0)
PostCutConfiguration.__doc__ = (
    "PostCut configuration class. This cut is employed after the training of the mva classifier."
)
PostCutConfiguration.value.__doc__ = (
    "Absolute value used to cut on the SignalProbability of each candidate."
)
PostCutConfiguration.bestCandidateCut.__doc__ = (
    "Number of best-candidates to keep, ranked by SignalProbability."
)
96
# One decay channel of a particle, modelled as an immutable record.
DecayChannel = collections.namedtuple(
    'DecayChannel',
    ['name', 'label', 'decayString', 'daughters', 'mvaConfig', 'preCutConfig', 'decayModeID', 'pi0veto'],
)
# Every field defaults to None except the trailing pi0veto flag, which defaults to False.
DecayChannel.__new__.__defaults__ = (None,) * 7 + (False,)
DecayChannel.__doc__ = "Decay channel of a Particle."
DecayChannel.name.__doc__ = "str:Name of the channel e.g. :code:`D0:generic_0`"
DecayChannel.label.__doc__ = (
    "Label used to identify the decay channel e.g. for weight files independent of decayModeID"
)
DecayChannel.decayString.__doc__ = "DecayDescriptor of the channel e.g. D0 -> K+ pi-"
DecayChannel.daughters.__doc__ = "List of daughter particles of the decay channel e.g. [K+, pi-]"
DecayChannel.mvaConfig.__doc__ = "MVAConfiguration object which is used for this channel."
DecayChannel.preCutConfig.__doc__ = "PreCutConfiguration object which is used for this channel."
DecayChannel.decayModeID.__doc__ = (
    "DecayModeID of this channel. Unique ID for each channel of this particle."
)
DecayChannel.pi0veto.__doc__ = (
    "If true, additional pi0veto variables are added to the MVAs, useful only for decays with gammas."
)
110
# Histogram binning per monitored variable. Each entry maps the variable name to the tuple
# (variable name, number of bins, lower edge, upper edge); the table below lists the name
# only once and the comprehension repeats it inside the value.
MonitoringVariableBinning = {
    name: (name, nbins, low, high)
    for name, nbins, low, high in (
        ('mcErrors', 513, -0.5, 512.5),
        ('mcParticleStatus', 257, -0.5, 256.5),
        ('dM', 100, -1.0, 1.0),
        ('dQ', 100, -1.0, 1.0),
        ('abs(dM)', 100, 0.0, 1.0),
        ('abs(dQ)', 100, 0.0, 1.0),
        ('pionID', 100, 0.0, 1.0),
        ('kaonID', 100, 0.0, 1.0),
        ('protonID', 100, 0.0, 1.0),
        ('electronID', 100, 0.0, 1.0),
        ('muonID', 100, 0.0, 1.0),
        ('isSignal', 2, -0.5, 1.5),
        ('isSignalAcceptMissingNeutrino', 2, -0.5, 1.5),
        ('isPrimarySignal', 2, -0.5, 1.5),
        ('chiProb', 100, 0.0, 1.0),
        ('Mbc', 100, 5.1, 5.4),
        ('cosThetaBetweenParticleAndNominalB', 100, -10.0, 10.0),
        ('extraInfo(SignalProbability)', 100, 0.0, 1.0),
        ('extraInfo(decayModeID)', 101, -0.5, 100.5),
        ('extraInfo(uniqueSignal)', 2, -0.5, 1.5),
        ('extraInfo(preCut_rank)', 41, -0.5, 40.5),
        ('extraInfo(postCut_rank)', 41, -0.5, 40.5),
        ('daughterProductOf(extraInfo(SignalProbability))', 100, 0.0, 1.0),
        ('pValueCombinationOfDaughters(extraInfo(SignalProbability))', 100, 0.0, 1.0),
    )
}
138
139
def variables2binnings(variables):
    """
    Translate variable names into binning tuples which can be given to VariableToHistogram.

    A variable listed in MonitoringVariableBinning gets its registered binning; any other
    variable gets the fallback (name, 100, -10.0, 10.0).
    @param variables iterable of variable names
    @return list with one binning tuple per variable, in input order
    """
    binnings = []
    for name in variables:
        fallback = (name, 100, -10.0, 10.0)
        binnings.append(MonitoringVariableBinning.get(name, fallback))
    return binnings
145
146
def variables2binnings_2d(variables):
    """
    Translate pairs of variable names into binning tuples which can be given to VariableToHistogram.

    The binning of each variable is looked up in MonitoringVariableBinning, falling back to
    (name, 100, -10.0, 10.0); the two binnings of a pair are concatenated into one tuple.
    @param variables iterable of (variable1, variable2) pairs
    @return list with one concatenated binning tuple per pair, in input order
    """
    def _binning(name):
        # Registered binning if there is one, otherwise the generic fallback
        return MonitoringVariableBinning.get(name, (name, 100, -10.0, 10.0))

    return [_binning(first) + _binning(second) for first, second in variables]
157
158
def removeJPsiSlash(string: str) -> str:
    """
    Strip every '/' from the given string, e.g. to turn the particle name J/psi into Jpsi.
    @param string text which may contain '/' characters
    @return the text with all '/' characters removed
    """
    pieces = string.split('/')
    return ''.join(pieces)
164
165
class Particle:

    """
    The Particle class is the only class the end-user gets into contact with.
    The user creates an instance of this class for every particle they want to reconstruct with the FEI algorithm,
    and provides MVAConfiguration, PreCutConfiguration and PostCutConfiguration. These can be overwritten per channel.
    """

    def __init__(self, identifier: str,
                 mvaConfig: MVAConfiguration,
                 preCutConfig: PreCutConfiguration = PreCutConfiguration(),
                 postCutConfig: PostCutConfiguration = PostCutConfiguration()):
        """
        Creates a Particle without any decay channels. To add decay channels use addChannel method.
            @param identifier is the pdg name of the particle as a string
                   with an optional additional user label separated by ':'
            @param mvaConfig multivariate analysis configuration
            @param preCutConfig intermediate pre cut configuration
            @param postCutConfig post cut configuration
        """
        # Full identifier '<name>:<label>'; the label 'generic' is appended if none was given
        self.identifier = identifier if ':' in identifier else identifier + ':generic'
        v = self.identifier.split(':')
        # Part of the identifier in front of the first ':'
        self.name = v[0]
        # Part of the identifier behind the first ':'
        self.label = v[1]
        # Default MVA configuration, used by channels which do not provide their own
        self.mvaConfig = mvaConfig
        # DecayChannel objects added via addChannel, in order of insertion
        self.channels = []
        # Default pre cut configuration, used by channels which do not provide their own
        self.preCutConfig = preCutConfig
        # Post cut configuration of this particle
        self.postCutConfig = postCutConfig

    def __eq__(self, a):
        """
        Compares two Particle objects.
        They are equal if their identifier, name, label, all channels, preCutConfig and postCutConfig are equal.
        The mvaConfig of the particle itself is not part of the comparison.
            @param a another Particle object
            @return True or False for Particle objects, NotImplemented for any other type
        """
        if not isinstance(a, Particle):
            # Let Python fall back to the other operand instead of failing on a missing attribute
            return NotImplemented
        return (self.identifier == a.identifier and self.name == a.name and self.label == a.label and
                self.channels == a.channels and self.preCutConfig == a.preCutConfig and self.postCutConfig == a.postCutConfig)

    def __str__(self):
        """
        Creates a string representation of a Particle object.
        """
        return str((self.identifier, self.channels, self.preCutConfig, self.postCutConfig, self.mvaConfig))

    def __hash__(self):
        """
        Creates a hash of a Particle object.
        This is necessary to use this as a key in a dictionary.

        Only the identifier is hashed: the channel list is a list and therefore not
        hashable, so including it would make every call raise a TypeError. Particles which
        compare equal always share their identifier and hence their hash, and the hash does
        not change when channels are added.
        """
        return hash(self.identifier)

    @property
    def daughters(self):
        """ Property returning list of unique daughter particles of all channels """
        # dict.fromkeys removes duplicates while keeping the order of first appearance,
        # so the result is reproducible (the iteration order of a set is arbitrary)
        return list(dict.fromkeys(daughter for channel in self.channels for daughter in channel.daughters))

    def addChannel(self,
                   daughters: typing.Sequence[str],
                   mvaConfig: typing.Optional[MVAConfiguration] = None,
                   preCutConfig: typing.Optional[PreCutConfiguration] = None,
                   pi0veto: bool = False) -> 'Particle':
        """
        Appends a new decay channel to the Particle object.

        Placeholders in the MVA variable names are expanded for the new channel:
        '{}' is replaced by all combinations of daughter indices, '{a..b}' by every index
        in the half-open range [a, b), where an omitted a means 0 and an omitted b means
        the number of daughters. Variables whose placeholders do not fit the number of
        daughters are dropped with a debug message.
            @param daughters is a list of pdg particle names e.g. ['pi+','K-']
            @param mvaConfig multivariate analysis configuration
            @param preCutConfig pre cut configuration object
            @param pi0veto if true, additional pi0veto variables are added to the MVA configuration
            @return this Particle, so calls can be chained
        """
        # Append generic label to all defined daughters if no label was set yet
        daughters = [d + ':generic' if ':' not in d else d for d in daughters]
        nDaughters = len(daughters)
        # Use default mvaConfig of this particle if no channel-specific config is given
        mvaConfig = copy.deepcopy(self.mvaConfig if mvaConfig is None else mvaConfig)
        # Use default preCutConfig of this particle if no channel-specific config is given
        preCutConfig = copy.deepcopy(self.preCutConfig if preCutConfig is None else preCutConfig)
        # At the moment all channels must have the same target variable. Why?
        if mvaConfig is not None and mvaConfig.target != self.mvaConfig.target:
            basf2.B2FATAL(
                f'Particle {self.identifier} has common target {self.mvaConfig.target}, while channel '
                f'{" ".join(daughters)} has {mvaConfig.target}. Each particle must have exactly one target!')
        # Replace generic-variables with ordinary variables.
        # All instances of {} are replaced with all combinations of daughter indices
        mvaVars = []
        for v in mvaConfig.variables:
            if v.count('{') == 0:
                mvaVars.append(v)
                continue
            matches = re.findall(r'\{\s*\d*\s*\.\.\s*\d*\s*\}', v)
            nBrackets = v.count('{}')
            if len(matches) == 0 and nBrackets == 0:
                # Contains a brace, but no placeholder known to this method: keep as it is
                mvaVars.append(v)
            elif nBrackets > 0 and len(matches) > 0:
                basf2.B2FATAL(f'Variable {v} contains both '+'{}'+f' and {matches}. Only one is allowed!')
            elif len(matches) > 0:
                ranges = []
                skip = False
                for match in matches:
                    # Drop the braces and split into the two bounds. The pattern allows
                    # whitespace around each bound, so strip it before testing for an
                    # omitted bound; otherwise int() would fail on e.g. '{ .. }'
                    low, high = (bound.strip() for bound in match[1:-1].split('..'))
                    if low == '':
                        low = 0
                    else:
                        low = int(low)
                        if low >= nDaughters:
                            basf2.B2DEBUG(11, f'Variable {v} contains index {low} which is more than daughters, skipping!')
                            skip = True
                            break
                    if high == '':
                        high = nDaughters
                    else:
                        high = int(high)
                        if high > nDaughters:
                            basf2.B2DEBUG(11, f'Variable {v} contains index {high} which is more than daughters, skipping!')
                            skip = True
                            break
                    ranges.append((low, high))
                if skip:
                    continue
                if len(ranges) == 1:
                    mvaVars += [v.replace(matches[0], str(c)) for c in range(*ranges[0])]
                else:
                    # Turn every range placeholder into '{}' and fill them with the
                    # cartesian product of all ranges
                    for match in matches:
                        v = v.replace(match, '{}')
                    mvaVars += [v.format(*c) for c in itertools.product(*[range(r[0], r[1]) for r in ranges])]
            elif nBrackets <= nDaughters:
                mvaVars += [v.format(*c) for c in itertools.combinations(range(nDaughters), nBrackets)]
            else:
                basf2.B2DEBUG(11, f'Variable {v} contains more brackets than daughters, which is why it will be ignored!')
                continue
        mvaConfig = mvaConfig._replace(variables=mvaVars)
        # Add new channel; its decayModeID is the position of the channel in the channel list
        decayModeID = len(self.channels)
        self.channels.append(DecayChannel(name=self.identifier + '_' + str(decayModeID),
                                          label=removeJPsiSlash(self.identifier + ' ==> ' + ' '.join(daughters)),
                                          decayString=self.identifier + '_' + str(decayModeID) + ' -> ' + ' '.join(daughters),
                                          daughters=daughters,
                                          mvaConfig=mvaConfig,
                                          preCutConfig=preCutConfig,
                                          decayModeID=decayModeID,
                                          pi0veto=pi0veto))
        return self
313
314# @endcond