# Belle II Software development
# flavorTagger.py
# This is the main flavortagger function
2
3import basf2
4from basf2 import B2INFO, B2FATAL, B2WARNING
5import modularAnalysis as ma
6from variables import variables as va
7from flavorTagger.utils import (
8 get_available_categories, set_FT_pid_aliases, read_yaml,
9 get_Belle_or_Belle2, set_GFlat_aliases, set_masked_vars,
10 fill_particle_lists, set_output_vars
11)
12from .event_level import event_level
13from .combiner_level import combiner_level
14import os
15
16
def flavorTagger(
    particleLists=None,
    weightFiles='B2nunubarBGx1',
    workingDirectory='.',
    # NOTE: the two list defaults below are shared between calls; they are
    # only read or rebound in this function and must never be mutated in place.
    combinerMethods=['TMVA-FBDT'],
    categories=[
        'Electron',
        'IntermediateElectron',
        'Muon',
        'IntermediateMuon',
        'KinLepton',
        'IntermediateKinLepton',
        'Kaon',
        'SlowPion',
        'FastHadron',
        'Lambda',
        'FSC',
        'MaximumPstar',
        'KaonPion'],
    maskName='FTDefaultMask',
    saveCategoriesInfo=True,
    useOnlyLocalWeightFiles=False,
    downloadFromDatabaseIfNotFound=False,
    prefix='MC16rd_light-2501-betelgeuse',
    useGNN=True,
    identifierGNN='GFlaT_MC16rd_light-2501-betelgeuse_tensorflow',
    usePIDNN=False,
    path=None,
):
    """
    Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
    The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
    Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
    This module can be used to sample the training information, to train and/or to test the flavorTagger.

    @param particleLists                     The ROEs for flavor tagging are selected from the given particle lists.
                                             A single list name is wrapped into a list. Must not be ``None``.
    @param weightFiles                       Weight files name. Default=
                                             ``B2nunubarBGx1`` (official weight files). If the user wants to train the
                                             FlavorTagger themselves, the weightfiles name should correspond to the
                                             analyzed CP channel in order to avoid confusions. The default name
                                             ``B2nunubarBGx1`` corresponds to
                                             :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`.
                                             and ``B2JpsiKs_muBGx1`` to
                                             :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
                                             BGx1 stands for events simulated with background.
    @param workingDirectory                  Path to the directory containing the FlavorTagging/ folder.
    @param combinerMethods                   MVAs for the combiner: ``TMVA-FBDT`` (default).
                                             ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
    @param categories                        Categories used for flavor tagging. By default all are used.
                                             Duplicates are removed (with a warning), keeping the given order.
    @param maskName                          Gets ROE particles from a specified ROE mask.
                                             ``FTDefaultMask`` (default): tentative mask definition that will be created
                                             automatically. The definition is as follows:

                                             - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
                                             - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and \
                                             [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or \
                                             [clusterReg==3 and E>0.06]] \
                                             (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)

                                             ``all``: all ROE particles are used.
                                             Or one can give any mask name defined before calling this function.
    @param saveCategoriesInfo                Sets to save information of individual categories.
    @param useOnlyLocalWeightFiles           [Expert] Uses only locally saved weight files.
    @param downloadFromDatabaseIfNotFound    [Expert] Weight files are downloaded from
                                             the conditions database if not available in workingDirectory.
    @param prefix                            Prefix of weight files.
                                             ``MC16rd_light-2501-betelgeuse`` (default): Weight files trained for MC16rd samples.
                                             ``MC15ri_light-2207-bengal_0``: Weight files trained for MC15ri samples.
                                             ``''``: Weight files trained for MC13 samples.
    @param useGNN                            Use GNN-based Flavor Tagger in addition to FastBDT-based one.
                                             Please specify the weight file with the option ``identifierGNN``.
                                             [Expert] In the sampler mode,
                                             training files for GNN-based Flavor Tagger are produced.
    @param identifierGNN                     The name of weight file of the GNN-based Flavor Tagger.
                                             [Expert] Multiple identifiers can be given with list(str).
    @param usePIDNN                          If True, PID probabilities calculated from PID neural network are used
                                             (default is False). Prefix and identifierGNN must be set accordingly.
    @param path                              Modules are added to this path. Must not be ``None``.

    """

    # set common config
    exp_type = get_Belle_or_Belle2()
    available_categories = get_available_categories()
    config_filepath = basf2.find_file('data/analysis/config.yaml')
    config_params = read_yaml(config_filepath)
    signal_fraction = config_params["signal_fraction"]
    download_folder = config_params["database_io"]["download_folder"]

    # sanitize and cross check inputs

    # fail early with a clear message: modules are added to `path` further
    # down, so a missing path could only end in an AttributeError there
    if path is None:
        B2FATAL('Flavor Tagger: No path was given. Please pass the path the modules should be added to.')

    # a missing particle list would be forwarded as [None] to the ROE modules
    if particleLists is None:
        B2FATAL('Flavor Tagger: No particle list was given. Please specify particleLists.')

    # force the particle list to be a list
    if not isinstance(particleLists, list):
        particleLists = [particleLists]

    # ensure unique categories in list; dict.fromkeys keeps the first
    # occurrence of each entry, so the resulting order is reproducible
    unique_categories = list(dict.fromkeys(categories))
    if len(unique_categories) != len(categories):
        dup = [cat for cat in unique_categories if categories.count(cat) > 1]
        B2WARNING("Flavor Tagger: There are duplicate elements in the given"
                  " categories list. The following duplicate elements are"
                  f" removed: {', '.join(dup)}")
    categories = unique_categories

    # B2FATAL aborts, so the list of valid names has to be part of the same message
    possible_categories = ', '.join(available_categories.keys())

    # ensure minimum category number
    if len(categories) < 2:
        B2FATAL('Flavor Tagger: At least two categories are needed. '
                f'Possible categories are {possible_categories}')

    # ensure legitimate categories
    for category in categories:
        if category not in available_categories:
            B2FATAL(f'Flavor Tagger: {category} is not a valid category name given. '
                    f'Possible categories are {possible_categories}')

    # ensure correct GNN config
    if useGNN and identifierGNN == '':
        B2FATAL('Please specify the name of the weight file with ``identifierGNN``')

    # ensure correct combiner method config
    if not combinerMethods:
        B2FATAL('Flavor Tagger: Please specify at least one combinerMethods.'
                ' The available methods are "TMVA-FBDT" and "FANN-MLP"')

    FANNmlp = False
    TMVAfbdt = False

    for method in combinerMethods:
        if method == 'TMVA-FBDT':
            TMVAfbdt = True
        elif method == 'FANN-MLP':
            FANNmlp = True
        else:
            B2FATAL('Flavor Tagger: Invalid list of combinerMethods. '
                    ' The available methods are "TMVA-FBDT" and "FANN-MLP"')

    # check if working directory exists for download
    basf2.find_file(workingDirectory)

    files_dir = f"{workingDirectory}/{download_folder}"
    if downloadFromDatabaseIfNotFound:
        if not basf2.find_file(files_dir, silent=True):
            # exist_ok: another job may create the folder between the check and this call
            os.makedirs(files_dir, exist_ok=True)

    # verbose
    B2INFO("\n"
           " *** FLAVOR TAGGING ***\n"
           f" Working directory is: {files_dir}\n ")

    # setup FT pid alias
    if prefix == '':
        set_FT_pid_aliases(type="MC13", exp_type=exp_type)
    else:
        set_FT_pid_aliases(type="Current", exp_type=exp_type)
        weightFiles = f"{prefix}_{weightFiles}"

    # set GNN aliases
    if useGNN:
        set_GFlat_aliases(categories, usePIDNN)

    # set input masked vars
    set_masked_vars()

    # Create configuration lists and code-name for given category's list
    trackLevelParticleLists = []
    eventLevelParticleLists = []
    variablesCombinerLevel = []
    categoriesCombination = []
    for category in categories:
        ftCategory = available_categories[category]

        track_tuple = (ftCategory.particleList, ftCategory.trackName)
        event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)

        if track_tuple not in trackLevelParticleLists and category != 'MaximumPstar':
            trackLevelParticleLists.append(track_tuple)

        if event_tuple not in eventLevelParticleLists:
            eventLevelParticleLists.append(event_tuple)
            variablesCombinerLevel.append(ftCategory.variableName)
            categoriesCombination.append(ftCategory.code)
        else:
            B2FATAL(f"Flavor Tagger: {category} has been already given")

    # the code-name is built from the sorted, zero-padded category codes
    categoriesCombinationCode = 'CatCode' + ''.join(
        f'{int(code):02}' for code in sorted(categoriesCombination)
    )

    # Create default ROE-mask
    if maskName == 'FTDefaultMask':
        FTDefaultMask = (
            'FTDefaultMask',
            'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
            'thetaInCDCAcceptance and clusterNHits>1.5 and '
            ' [[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or [E>0.06 and clusterReg==3]]'
        )
        for name in particleLists:
            ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)

    # Start ROE-routine
    roe_path = basf2.create_path()
    deadEndPath = basf2.create_path()

    # If trigger returns 1 jump into empty path skipping further modules in roe_path
    # run filter with no cut first to get rid of ROEs that are missing the mask of the signal particle
    ma.signalSideParticleListsFilter(particleLists, f'nROE_Charged({maskName}, 0) > 0', roe_path, deadEndPath)

    # Initialization of flavorTaggerInfo dataObject needs to be done in the main path
    flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
    path.add_module(flavorTaggerInfoBuilder)

    # fill particle lists
    fill_particle_lists(maskName, categories, roe_path)

    # NOTE(review): the combiner and the info filler are only set up when
    # event_level() returns a true value - nesting reconstructed, please confirm
    if event_level(
        weightFiles=weightFiles,
        categories=categories,
        files_dir=files_dir,
        useOnlyLocalFlag=useOnlyLocalWeightFiles,
        downloadFlag=downloadFromDatabaseIfNotFound,
        exp_type=exp_type,
        signal_fraction=signal_fraction,
        path=roe_path
    ):

        combiner_level(
            weightFiles=weightFiles,
            categories=categories,
            variablesCombinerLevel=variablesCombinerLevel,
            categoriesCombinationCode=categoriesCombinationCode,
            TMVAfbdt=TMVAfbdt,
            FANNmlp=FANNmlp,
            downloadFlag=downloadFromDatabaseIfNotFound,
            useOnlyLocalFlag=useOnlyLocalWeightFiles,
            signal_fraction=signal_fraction,
            filesDirectory=files_dir,
            path=roe_path
        )

        flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
        flavorTaggerInfoFiller.param('trackLevelParticleLists', trackLevelParticleLists)
        flavorTaggerInfoFiller.param('eventLevelParticleLists', eventLevelParticleLists)
        flavorTaggerInfoFiller.param('TMVAfbdt', TMVAfbdt)
        flavorTaggerInfoFiller.param('FANNmlp', FANNmlp)
        flavorTaggerInfoFiller.param('qpCategories', saveCategoriesInfo)
        flavorTaggerInfoFiller.param('istrueCategories', saveCategoriesInfo)
        flavorTaggerInfoFiller.param('targetProb', False)
        flavorTaggerInfoFiller.param('trackPointers', False)
        roe_path.add_module(flavorTaggerInfoFiller)  # Add FlavorTag Info filler to roe_path
        set_output_vars()

    if useGNN:
        ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                         outputVariable='FT_p_rank', overwriteRank=True,
                         path=roe_path)
        ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)

        if isinstance(identifierGNN, str):
            roe_path.add_module(
                'MVAExpert',
                listNames='vpho:dummy',
                extraInfoName='qrGNN_raw',  # the range of qrGNN_raw is [0,1]
                identifier=identifierGNN
            )
            # rescale the raw output from [0,1] to [-1,1]
            ma.variableToSignalSideExtraInfo(
                'vpho:dummy',
                {'extraInfo(qrGNN_raw)*2-1': 'qrGNN'},
                path=roe_path
            )
        elif isinstance(identifierGNN, list):
            # drop repeated identifiers but keep the order given by the user
            identifierGNN = list(dict.fromkeys(identifierGNN))

            extraInfoNames = [f'qrGNN_{i_id}' for i_id in identifierGNN]
            roe_path.add_module(
                'MVAMultipleExperts',
                listNames='vpho:dummy',
                extraInfoNames=extraInfoNames,
                identifiers=identifierGNN
            )

            extraInfoDict = {}
            for extraInfoName in extraInfoNames:
                extraInfoDict[f'extraInfo({extraInfoName})*2-1'] = extraInfoName
                va.addAlias(extraInfoName, f'extraInfo({extraInfoName})')

            ma.variableToSignalSideExtraInfo(
                'vpho:dummy',
                extraInfoDict,
                path=roe_path
            )

    path.for_each('RestOfEvent', 'RestOfEvents', roe_path)