4from basf2
import B2INFO, B2FATAL, B2WARNING
5import modularAnalysis
as ma
6from variables
import variables
as va
8 get_available_categories, set_FT_pid_aliases, read_yaml,
9 get_Belle_or_Belle2, set_GFlat_aliases, set_masked_vars,
10 fill_particle_lists, set_output_vars
12from .event_level
import event_level
13from .combiner_level
import combiner_level
19 weightFiles='B2nunubarBGx1',
21 combinerMethods=['TMVA-FBDT'],
24 'IntermediateElectron',
28 'IntermediateKinLepton',
36 maskName='FTDefaultMask',
37 saveCategoriesInfo=True,
38 useOnlyLocalWeightFiles=False,
39 downloadFromDatabaseIfNotFound=False,
40 prefix='MC16rd_light-2501-betelgeuse',
42 identifierGNN='GFlaT_MC16rd_light-2501-betelgeuse_tensorflow',
47 Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
48 The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
49 Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
50 This module can be used to sample the training information, to train and/or to test the flavorTagger.
52 @param particleLists The ROEs for flavor tagging are selected from the given particle lists.
53 @param weightFiles Weight files name. Default=
54 ``B2nunubarBGx1`` (official weight files). If the user wants to train the
55 FlavorTagger themselves, the weightfiles name should correspond to the
56 analyzed CP channel in order to avoid confusion. The default name
57 ``B2nunubarBGx1`` corresponds to
58 :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`
59 and ``B2JpsiKs_muBGx1`` to
60 :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
61 BGx1 stands for events simulated with background.
62 @param workingDirectory Path to the directory containing the FlavorTagging/ folder.
63 @param combinerMethods MVAs for the combiner: ``TMVA-FBDT`` (default).
64 ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
65 @param categories Categories used for flavor tagging. By default all are used.
66 @param maskName Gets ROE particles from a specified ROE mask.
67 ``FTDefaultMask`` (default): tentative mask definition that will be created
68 automatically. The definition is as follows:
70 - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
71 - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and \
72 [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or \
73 [clusterReg==3 and E>0.06]] \
74 (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)
76 ``all``: all ROE particles are used.
77 Or one can give any mask name defined before calling this function.
78 @param saveCategoriesInfo If True, the information of the individual categories is saved.
79 @param useOnlyLocalWeightFiles [Expert] Uses only locally saved weight files.
80 @param downloadFromDatabaseIfNotFound [Expert] Weight files are downloaded from
81 the conditions database if not available in workingDirectory.
82 @param prefix Prefix of weight files.
83 ``MC16rd_light-2501-betelgeuse`` (default): Weight files trained for MC16rd samples.
84 ``MC15ri_light-2207-bengal_0``: Weight files trained for MC15ri samples.
85 ``''``: Weight files trained for MC13 samples.
86 @param useGNN Use the GNN-based Flavor Tagger in addition to the FastBDT-based one.
87 Please specify the weight file with the option ``identifierGNN``.
88 [Expert] In the sampler mode,
89 training files for GNN-based Flavor Tagger are produced.
90 @param identifierGNN The name of weight file of the GNN-based Flavor Tagger.
91 [Expert] Multiple identifiers can be given with list(str).
92 @param usePIDNN If True, PID probabilities calculated from PID neural network are used
93 (default is False). Prefix and identifierGNN must be set accordingly.
94 @param path Modules are added to this path
99 exp_type = get_Belle_or_Belle2()
100 available_categories = get_available_categories()
101 config_filepath = basf2.find_file(
'data/analysis/config.yaml')
102 config_params = read_yaml(config_filepath)
103 signal_fraction = config_params[
"signal_fraction"]
104 download_folder = config_params[
"database_io"][
"download_folder"]
109 if (
not isinstance(particleLists, list)):
110 particleLists = [particleLists]
113 if len(categories) != len(set(categories)):
114 dup = [cat
for cat
in set(categories)
if categories.count(cat) > 1]
115 B2WARNING(f
"Flavor Tagger: There are duplicate elements in the given\
116 categories list. The following duplicate elements are\
117 removed: {', '.join(dup)}")
118 categories = list(set(categories))
121 if len(categories) < 2:
122 B2FATAL(
'Flavor Tagger: At least two categories are needed.')
123 B2FATAL(f
"Flavor Tagger: Possible categories are {available_categories.keys()}")
126 for category
in categories:
127 if category
not in available_categories:
128 B2FATAL(
'Flavor Tagger: ' + category +
' is not a valid category name given')
129 B2FATAL(f
"Flavor Tagger: Possible categories are {available_categories.keys()}")
132 if useGNN
and identifierGNN ==
'':
133 B2FATAL(
'Please specify the name of the weight file with ``identifierGNN``')
136 if len(combinerMethods) == 0:
137 B2FATAL(
'Flavor Tagger: Please specify at least one combinerMethods.\
138 The available methods are "TMVA-FBDT" and "FANN-MLP"')
143 for method
in combinerMethods:
144 if method ==
'TMVA-FBDT':
146 elif method ==
'FANN-MLP':
149 B2FATAL(
'Flavor Tagger: Invalid list of combinerMethods. \
150 The available methods are "TMVA-FBDT" and "FANN-MLP"')
153 basf2.find_file(workingDirectory)
155 files_dir = f
"{workingDirectory}/{download_folder}"
156 if downloadFromDatabaseIfNotFound:
157 if not basf2.find_file(files_dir, silent=
True):
158 os.makedirs(files_dir)
162 *** FLAVOR TAGGING ***
163 Working directory is: {files_dir}
168 set_FT_pid_aliases(type=
"MC13", exp_type=exp_type)
170 set_FT_pid_aliases(type=
"Current", exp_type=exp_type)
171 weightFiles = f
"{prefix}_{weightFiles}"
175 set_GFlat_aliases(categories, usePIDNN)
181 trackLevelParticleLists = []
182 eventLevelParticleLists = []
183 variablesCombinerLevel = []
184 categoriesCombination = []
185 categoriesCombinationCode =
'CatCode'
186 for category
in categories:
187 ftCategory = available_categories[category]
189 track_tuple = (ftCategory.particleList, ftCategory.trackName)
190 event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)
192 if track_tuple
not in trackLevelParticleLists
and category !=
'MaximumPstar':
193 trackLevelParticleLists.append(track_tuple)
195 if event_tuple
not in eventLevelParticleLists:
196 eventLevelParticleLists.append(event_tuple)
197 variablesCombinerLevel.append(ftCategory.variableName)
198 categoriesCombination.append(ftCategory.code)
200 B2FATAL(f
"Flavor Tagger: {category} has been already given")
202 for code
in sorted(categoriesCombination):
203 categoriesCombinationCode = categoriesCombinationCode + f
'{int(code):02}'
206 if maskName ==
'FTDefaultMask':
209 'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
210 'thetaInCDCAcceptance and clusterNHits>1.5 and \
211 [[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or [E>0.06 and clusterReg==3]]'
213 for name
in particleLists:
214 ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)
217 roe_path = basf2.create_path()
218 deadEndPath = basf2.create_path()
222 ma.signalSideParticleListsFilter(particleLists, f
'nROE_Charged({maskName}, 0) > 0', roe_path, deadEndPath)
225 flavorTaggerInfoBuilder = basf2.register_module(
'FlavorTaggerInfoBuilder')
226 path.add_module(flavorTaggerInfoBuilder)
229 fill_particle_lists(maskName, categories, roe_path)
232 weightFiles=weightFiles,
233 categories=categories,
235 useOnlyLocalFlag=useOnlyLocalWeightFiles,
236 downloadFlag=downloadFromDatabaseIfNotFound,
238 signal_fraction=signal_fraction,
243 weightFiles=weightFiles,
244 categories=categories,
245 variablesCombinerLevel=variablesCombinerLevel,
246 categoriesCombinationCode=categoriesCombinationCode,
249 downloadFlag=downloadFromDatabaseIfNotFound,
250 useOnlyLocalFlag=useOnlyLocalWeightFiles,
251 signal_fraction=signal_fraction,
252 filesDirectory=files_dir,
256 flavorTaggerInfoFiller = basf2.register_module(
'FlavorTaggerInfoFiller')
257 flavorTaggerInfoFiller.param(
'trackLevelParticleLists', trackLevelParticleLists)
258 flavorTaggerInfoFiller.param(
'eventLevelParticleLists', eventLevelParticleLists)
259 flavorTaggerInfoFiller.param(
'TMVAfbdt', TMVAfbdt)
260 flavorTaggerInfoFiller.param(
'FANNmlp', FANNmlp)
261 flavorTaggerInfoFiller.param(
'qpCategories', saveCategoriesInfo)
262 flavorTaggerInfoFiller.param(
'istrueCategories', saveCategoriesInfo)
263 flavorTaggerInfoFiller.param(
'targetProb',
False)
264 flavorTaggerInfoFiller.param(
'trackPointers',
False)
265 roe_path.add_module(flavorTaggerInfoFiller)
269 ma.rankByHighest(
'pi+:inRoe',
'p', numBest=0, allowMultiRank=
False,
270 outputVariable=
'FT_p_rank', overwriteRank=
True,
272 ma.fillParticleListFromDummy(
'vpho:dummy', path=roe_path)
274 if isinstance(identifierGNN, str):
277 listNames=
'vpho:dummy',
278 extraInfoName=
'qrGNN_raw',
279 identifier=identifierGNN
281 ma.variableToSignalSideExtraInfo(
283 {
'extraInfo(qrGNN_raw)*2-1':
'qrGNN'},
286 elif isinstance(identifierGNN, list):
287 identifierGNN = list(set(identifierGNN))
289 extraInfoNames = [f
'qrGNN_{i_id}' for i_id
in identifierGNN]
291 'MVAMultipleExperts',
292 listNames=
'vpho:dummy',
293 extraInfoNames=extraInfoNames,
294 identifiers=identifierGNN
298 for extraInfoName
in extraInfoNames:
299 extraInfoDict[f
'extraInfo({extraInfoName})*2-1'] = extraInfoName
300 va.addAlias(extraInfoName, f
'extraInfo({extraInfoName})')
302 ma.variableToSignalSideExtraInfo(
308 path.for_each(
'RestOfEvent',
'RestOfEvents', roe_path)