Belle II Software development
flavorTagger.py
1# This is the main flavortagger function
2
3import basf2
4from basf2 import B2INFO, B2FATAL, B2WARNING
5import modularAnalysis as ma
6from variables import variables as va
7from flavorTagger.utils import (
8 get_available_categories, set_FT_pid_aliases, read_yaml,
9 get_Belle_or_Belle2, set_GFlat_aliases, set_masked_vars,
10 fill_particle_lists, set_output_vars
11)
12from .event_level import event_level
13from .combiner_level import combiner_level
14import os
15
16
def _sanitize_categories(categories, available_categories):
    """Validate the requested flavor-tagging categories and return them as a new list.

    Duplicates are dropped with a warning. The first occurrence of each name is kept, so the
    order chosen by the user is preserved and the result is deterministic (deduplicating via
    ``set`` would make the order depend on string-hash randomisation). The job is aborted with
    ``B2FATAL`` if fewer than two categories remain or if a name is not a key of
    ``available_categories``.

    @param categories Iterable of category names given by the user.
    @param available_categories Mapping whose keys are the valid category names.
    @return A fresh list of unique category names (never the caller's object, so the
            mutable default argument of ``flavorTagger`` is never handed on).
    """
    requested = list(categories)
    unique = list(dict.fromkeys(requested))

    # ensure unique categories in list
    if len(unique) != len(requested):
        dup = [cat for cat in unique if requested.count(cat) > 1]
        B2WARNING("Flavor Tagger: There are duplicate elements in the given categories list. "
                  f"The following duplicate elements are removed: {', '.join(map(str, dup))}")

    # B2FATAL ends the job, so the list of valid names must be part of the same message
    possible = ', '.join(map(str, available_categories))

    # ensure minimum category number
    if len(unique) < 2:
        B2FATAL('Flavor Tagger: At least two categories are needed. '
                f'Possible categories are: {possible}')

    # ensure legitimate categories
    for category in unique:
        if category not in available_categories:
            B2FATAL(f'Flavor Tagger: {category} is not a valid category name given. '
                    f'Possible categories are: {possible}')

    return unique


def _parse_combiner_methods(combinerMethods):
    """Translate the combiner MVA names into the flags ``(TMVAfbdt, FANNmlp)``.

    A single method name given as a plain string is accepted as well. The job is aborted with
    ``B2FATAL`` if no method or an unknown method is given.

    @param combinerMethods List of method names, each ``TMVA-FBDT`` or ``FANN-MLP``.
    @return Tuple ``(TMVAfbdt, FANNmlp)`` of booleans telling which combiners are requested.
    """
    available_hint = 'The available methods are "TMVA-FBDT" and "FANN-MLP"'

    # a bare string would otherwise be iterated character by character
    if isinstance(combinerMethods, str):
        combinerMethods = [combinerMethods]

    if not combinerMethods:
        B2FATAL(f'Flavor Tagger: Please specify at least one combinerMethods. {available_hint}')

    TMVAfbdt = False
    FANNmlp = False
    for method in combinerMethods:
        if method == 'TMVA-FBDT':
            TMVAfbdt = True
        elif method == 'FANN-MLP':
            FANNmlp = True
        else:
            B2FATAL(f'Flavor Tagger: Invalid list of combinerMethods ({method!r} is unknown). '
                    f'{available_hint}')
    return TMVAfbdt, FANNmlp


def _add_gnn_experts(identifierGNN, roe_path):
    """Add the GNN-based flavor tagger (GFlaT) expert modules to ``roe_path``.

    With a single identifier (str) the output is stored on the signal side as ``qrGNN``.
    With several identifiers (list/tuple, duplicates dropped in order) one
    ``qrGNN_<identifier>`` extra info and a variable alias of the same name are created per
    identifier.

    @param identifierGNN Weight-file identifier (str) or a list/tuple of identifiers.
    @param roe_path Path executed for each RestOfEvent.
    """
    ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                     outputVariable='FT_p_rank', overwriteRank=True,
                     path=roe_path)
    ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)

    if isinstance(identifierGNN, str):
        roe_path.add_module(
            'MVAExpert',
            listNames='vpho:dummy',
            extraInfoName='qrGNN_raw',  # the range of qrGNN_raw is [0,1]
            identifier=identifierGNN
        )
        # x*2-1 maps the [0, 1] expert output onto [-1, 1]
        ma.variableToSignalSideExtraInfo(
            'vpho:dummy',
            {'extraInfo(qrGNN_raw)*2-1': 'qrGNN'},
            path=roe_path
        )
        return

    # several weight files: drop duplicates but keep the user-given (deterministic) order
    identifiers = list(dict.fromkeys(identifierGNN))

    # NOTE(review): identifiers containing '-' end up in alias names; confirm the variable
    # manager accepts such names.
    extraInfoNames = [f'qrGNN_{identifier}' for identifier in identifiers]
    roe_path.add_module(
        'MVAMultipleExperts',
        listNames='vpho:dummy',
        extraInfoNames=extraInfoNames,
        identifiers=identifiers
    )

    extraInfoDict = {}
    for extraInfoName in extraInfoNames:
        extraInfoDict[f'extraInfo({extraInfoName})*2-1'] = extraInfoName
        va.addAlias(extraInfoName, f'extraInfo({extraInfoName})')

    ma.variableToSignalSideExtraInfo(
        'vpho:dummy',
        extraInfoDict,
        path=roe_path
    )


def flavorTagger(
    particleLists=None,
    weightFiles='B2nunubarBGx1',
    workingDirectory='.',
    combinerMethods=['TMVA-FBDT'],
    categories=[
        'Electron',
        'IntermediateElectron',
        'Muon',
        'IntermediateMuon',
        'KinLepton',
        'IntermediateKinLepton',
        'Kaon',
        'SlowPion',
        'FastHadron',
        'Lambda',
        'FSC',
        'MaximumPstar',
        'KaonPion'],
    maskName='FTDefaultMask',
    saveCategoriesInfo=True,
    useOnlyLocalWeightFiles=False,
    downloadFromDatabaseIfNotFound=False,
    prefix='MC15ri_light-2207-bengal_0',
    useGNN=True,
    identifierGNN='GFlaT_MC15ri_light_2303_iriomote_0',
    usePIDNN=False,
    path=None,
):
    """
    Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
    The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
    Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
    This module can be used to sample the training information, to train and/or to test the flavorTagger.

    Invalid input (no ``path`` or ``particleLists``, fewer than two or unknown categories,
    unknown combiner methods, missing ``identifierGNN``) aborts the job with ``B2FATAL``.

    @param particleLists  The ROEs for flavor tagging are selected from the given particle lists.
                          A single list name (str) is accepted as well.
    @param weightFiles    Weight files name. Default=
                          ``B2nunubarBGx1`` (official weight files). If the user wants to train the
                          FlavorTagger themselves, the weight files name should correspond to the
                          analyzed CP channel in order to avoid confusion. The default name
                          ``B2nunubarBGx1`` corresponds to
                          :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`,
                          and ``B2JpsiKs_muBGx1`` to
                          :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
                          BGx1 stands for events simulated with background.
    @param workingDirectory  Path to the directory containing the FlavorTagging/ folder.
    @param combinerMethods  MVAs for the combiner: ``TMVA-FBDT`` (default).
                            ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
    @param categories     Categories used for flavor tagging. By default all are used.
                          Duplicates are removed (keeping the given order) with a warning.
    @param maskName       Gets ROE particles from a specified ROE mask.
                          ``FTDefaultMask`` (default): tentative mask definition that will be created
                          automatically. The definition is as follows:

                          - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
                          - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and
                            [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or
                            [clusterReg==3 and E>0.06]]
                            (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)

                          ``all``: all ROE particles are used.
                          Or one can give any mask name defined before calling this function.
    @param saveCategoriesInfo  Whether to save the information of the individual categories.
    @param useOnlyLocalWeightFiles  [Expert] Uses only locally saved weight files.
    @param downloadFromDatabaseIfNotFound  [Expert] Weight files are downloaded from
                          the conditions database if not available in workingDirectory.
    @param prefix         Prefix of weight files.
                          ``MC15ri_light-2207-bengal_0`` (default): Weight files trained with MC15ri samples.
                          ``MC16rd_light-2501-betelgeuse``: Weight files trained with MC16rd samples.
                          ``''``: Weight files trained for MC13 samples.
    @param useGNN         Use the GNN-based Flavor Tagger in addition to the FastBDT-based one.
                          Please specify the weight file with the option ``identifierGNN``.
                          [Expert] In the sampler mode,
                          training files for the GNN-based Flavor Tagger are produced.
    @param identifierGNN  The name of the weight file of the GNN-based Flavor Tagger.
                          ``GFlaT_MC15ri_light_2303_iriomote_0`` (default): Trained with MC15ri samples
                          ``GFlaT_MC16rd_light-2501-betelgeuse_tensorflow``: Trained with MC16rd samples
                          [Expert] Multiple identifiers can be given with list(str).
    @param usePIDNN       If True, PID probabilities calculated from the PID neural network are used
                          (default is False). Prefix and identifierGNN must be set accordingly.
    @param path           Modules are added to this path (required).

    """

    # set common config
    exp_type = get_Belle_or_Belle2()
    available_categories = get_available_categories()
    config_params = read_yaml(basf2.find_file('data/analysis/config.yaml'))
    signal_fraction = config_params["signal_fraction"]
    download_folder = config_params["database_io"]["download_folder"]

    # sanitize and cross check inputs
    if path is None:
        B2FATAL('Flavor Tagger: Please specify the path to which the modules are added with ``path``.')
    if particleLists is None:
        B2FATAL('Flavor Tagger: Please specify the signal-side particle list(s) with ``particleLists``.')

    # force the particle list to be a list
    if isinstance(particleLists, tuple):
        particleLists = list(particleLists)
    elif not isinstance(particleLists, list):
        particleLists = [particleLists]

    # unique (order-preserving), valid categories; always a new list
    categories = _sanitize_categories(categories, available_categories)

    # ensure correct GNN config
    if useGNN:
        if not identifierGNN:
            B2FATAL('Please specify the name of the weight file with ``identifierGNN``')
        if not isinstance(identifierGNN, (str, list, tuple)):
            B2FATAL('Flavor Tagger: ``identifierGNN`` must be a str or a list of str.')

    # ensure correct combiner method config
    TMVAfbdt, FANNmlp = _parse_combiner_methods(combinerMethods)

    # check if working directory exists for download
    basf2.find_file(workingDirectory)

    files_dir = f"{workingDirectory}/{download_folder}"
    if downloadFromDatabaseIfNotFound:
        # create the local download folder; find_file would also accept a match in the
        # release directories and then skip creating the local one
        os.makedirs(files_dir, exist_ok=True)

    # verbose
    B2INFO(f"""
    *** FLAVOR TAGGING ***
    Working directory is: {files_dir}
    """)

    # setup FT pid alias; only non-MC13 weight files carry the prefix in their name
    if prefix == '':
        set_FT_pid_aliases(type="MC13", exp_type=exp_type)
    else:
        set_FT_pid_aliases(type="Current", exp_type=exp_type)
        weightFiles = f"{prefix}_{weightFiles}"

    # set GNN aliases
    if useGNN:
        set_GFlat_aliases(categories, usePIDNN)

    # set input masked vars
    set_masked_vars()

    # Create configuration lists and code-name for given category's list
    trackLevelParticleLists = []
    eventLevelParticleLists = []
    variablesCombinerLevel = []
    categoriesCombination = []
    for category in categories:
        ftCategory = available_categories[category]

        track_tuple = (ftCategory.particleList, ftCategory.trackName)
        event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)

        # MaximumPstar is never added to the track-level lists
        if track_tuple not in trackLevelParticleLists and category != 'MaximumPstar':
            trackLevelParticleLists.append(track_tuple)

        if event_tuple not in eventLevelParticleLists:
            eventLevelParticleLists.append(event_tuple)
            variablesCombinerLevel.append(ftCategory.variableName)
            categoriesCombination.append(ftCategory.code)
        else:
            B2FATAL(f"Flavor Tagger: {category} has been already given")

    # 'CatCode' followed by the sorted category codes, two digits each; handed to
    # combiner_level (presumably to identify the combiner weight files)
    categoriesCombinationCode = 'CatCode' + ''.join(
        f'{int(code):02}' for code in sorted(categoriesCombination))

    # Create default ROE-mask
    if maskName == 'FTDefaultMask':
        FTDefaultMask = (
            'FTDefaultMask',
            'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
            'thetaInCDCAcceptance and clusterNHits>1.5 and '
            '[[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or [E>0.06 and clusterReg==3]]'
        )
        for name in particleLists:
            ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)

    # Start ROE-routine
    roe_path = basf2.create_path()
    deadEndPath = basf2.create_path()

    # ROEs failing the cut jump into the empty deadEndPath, skipping further modules in roe_path.
    # The nROE_Charged cut also discards ROEs that are missing the mask of the signal particle.
    ma.signalSideParticleListsFilter(particleLists, f'nROE_Charged({maskName}, 0) > 0', roe_path, deadEndPath)

    # Initialization of flavorTaggerInfo dataObject needs to be done in the main path
    flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
    path.add_module(flavorTaggerInfoBuilder)

    # fill particle lists
    fill_particle_lists(maskName, categories, roe_path)

    if event_level(
        weightFiles=weightFiles,
        categories=categories,
        files_dir=files_dir,
        useOnlyLocalFlag=useOnlyLocalWeightFiles,
        downloadFlag=downloadFromDatabaseIfNotFound,
        exp_type=exp_type,
        signal_fraction=signal_fraction,
        path=roe_path
    ):

        combiner_level(
            weightFiles=weightFiles,
            categories=categories,
            variablesCombinerLevel=variablesCombinerLevel,
            categoriesCombinationCode=categoriesCombinationCode,
            TMVAfbdt=TMVAfbdt,
            FANNmlp=FANNmlp,
            downloadFlag=downloadFromDatabaseIfNotFound,
            useOnlyLocalFlag=useOnlyLocalWeightFiles,
            signal_fraction=signal_fraction,
            filesDirectory=files_dir,
            path=roe_path
        )

        flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
        flavorTaggerInfoFiller.param('trackLevelParticleLists', trackLevelParticleLists)
        flavorTaggerInfoFiller.param('eventLevelParticleLists', eventLevelParticleLists)
        flavorTaggerInfoFiller.param('TMVAfbdt', TMVAfbdt)
        flavorTaggerInfoFiller.param('FANNmlp', FANNmlp)
        flavorTaggerInfoFiller.param('qpCategories', saveCategoriesInfo)
        flavorTaggerInfoFiller.param('istrueCategories', saveCategoriesInfo)
        flavorTaggerInfoFiller.param('targetProb', False)
        flavorTaggerInfoFiller.param('trackPointers', False)
        roe_path.add_module(flavorTaggerInfoFiller)  # Add FlavorTag Info filler to roe_path
        set_output_vars()

    if useGNN:
        _add_gnn_experts(identifierGNN, roe_path)

    path.for_each('RestOfEvent', 'RestOfEvents', roe_path)
310 path.for_each('RestOfEvent', 'RestOfEvents', roe_path)