Belle II Software  release-05-01-25
steering_training_data.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
13 
14 from dft.DeepFlavorTagger import *
15 
16 
def create_train_data(
        working_dir,
        file_names,
        identifier,
        variable_list,
        environmentType='MC5',
        target='qrCombined',
        overwrite=False,
        max_events=0,
        mode='sampler',
        *args,
        **kwargs):
    """Select B0 -> nu_tau anti-nu_tau signal candidates and run DeepFlavorTagger on them.

    Builds a basf2 path that reads the input files, reconstructs ``B0:sig`` from the
    MC decay ``B0 -> nu_tau anti-nu_tau``, keeps only MC-matched candidates
    (``isSignal > 0.5``), builds their rest of event and adds DeepFlavorTagger in the
    requested ``mode``. The path is then processed and the statistics are printed.

    Args:
        working_dir: directory handed to DeepFlavorTagger; created first if it is a
            non-empty path that does not exist yet.
        file_names: input file list passed to ``inputMdstList``.
        identifier: identifier forwarded to DeepFlavorTagger.
        variable_list: variable list forwarded to DeepFlavorTagger.
        environmentType: environment type passed to ``inputMdstList``.
        target: forwarded to DeepFlavorTagger as ``target``.
        overwrite: forwarded to DeepFlavorTagger as ``overwrite``.
        max_events: passed to ``process``; presumably 0 means "all events"
            (basf2 convention -- confirm).
        mode: DeepFlavorTagger mode, ``'sampler'`` by default.
        *args: extra positional arguments, passed to DeepFlavorTagger right after
            ``variable_list``.
        **kwargs: extra keyword arguments forwarded to DeepFlavorTagger.
    """
    main = create_path()

    # Test truthiness first: an identity check against '' (`is not ''`) is unreliable,
    # and checking emptiness before touching the filesystem also avoids passing
    # None/'' to os.path.exists / os.makedirs.
    if working_dir and not os.path.exists(working_dir):
        os.makedirs(working_dir, exist_ok=True)

    inputMdstList(environmentType, filelist=file_names, path=main)

    findMCDecay('B0:sig', 'B0 -> nu_tau anti-nu_tau', writeOut=True, path=main)
    matchMCTruth('B0:sig', main)
    applyCuts('B0:sig', 'isSignal > 0.5', path=main)

    buildRestOfEvent('B0:sig', path=main)

    # *args are positional and follow variable_list; writing them next to the other
    # positionals makes that explicit (same call as before).
    DeepFlavorTagger('B0:sig', mode, working_dir, identifier, variable_list, *args,
                     target=target, overwrite=overwrite, path=main, **kwargs)

    main.add_module('ProgressBar')

    process(main, max_events)
    print(statistics)
49 
50 
def test_expert(working_dir, file_names, identifier, output_variable='networkOutput', environmentType='MC5',
                max_events=0):
    """Apply the DeepFlavorTagger expert to B0 -> nu_tau anti-nu_tau signal candidates.

    Reconstructs ``B0:sig`` from the MC decay, keeps MC-matched candidates
    (``isSignal > 0.5``), builds their rest of event and runs DeepFlavorTagger in
    ``'expert'`` mode. ``extraInfo(qrCombined)`` and ``extraInfo(<output_variable>)``
    are written to ``<working_dir>/<identifier>_test_output.root``.

    Args:
        working_dir: directory passed to DeepFlavorTagger and used for the output ntuple.
        file_names: input file list passed to ``inputMdstList``.
        identifier: identifier forwarded to DeepFlavorTagger; also prefixes the output file name.
        output_variable: name of the extraInfo entry to store alongside ``qrCombined``.
        environmentType: environment type passed to ``inputMdstList``.
        max_events: passed to ``process``; presumably 0 means "all events" (confirm).
    """
    main = create_path()

    inputMdstList(environmentType, file_names, path=main)

    findMCDecay('B0:sig', 'B0 -> nu_tau anti-nu_tau', writeOut=True, path=main)
    matchMCTruth('B0:sig', main)
    applyCuts('B0:sig', 'isSignal > 0.5', path=main)

    buildRestOfEvent('B0:sig', path=main)

    DeepFlavorTagger('B0:sig', 'expert', working_dir, identifier, path=main)

    # Build "extraInfo(<output_variable>)". str.join takes a single iterable, so the
    # previous ''.join('extraInfo(', output_variable, ')') raised TypeError.
    output_variable_name = 'extraInfo({})'.format(output_variable)

    variablesToNtuple('B0:sig', ['extraInfo(qrCombined)', output_variable_name],
                      filename=os.path.join(working_dir, identifier + '_test_output.root'),
                      path=main)

    main.add_module('ProgressBar')

    process(main, max_events)
    print(statistics)
77 
78 
def test_expert_jpsi(working_dir, file_names, prefix, environmentType='MC5', max_events=0):
    """Apply the DeepFlavorTagger expert to reconstructed B0 -> J/psi(mu+ mu-) K_S0(pi+ pi-) candidates.

    Builds pion/muon lists (``piid >= .1`` / ``muid >= .1``), reconstructs and
    vertex-fits K_S0:pipi, J/psi:mumu and B0:jpsiks, MC-matches the B0 candidates,
    builds their rest of event and keeps only ``isSignal > 0.5``. DeepFlavorTagger is
    then run in ``'expert'`` mode with ``transform_to_probability=True``. Selected
    extraInfo variables are written to ``<working_dir>/test_output.root``.

    Args:
        working_dir: directory passed to DeepFlavorTagger and used for the output ntuple.
        file_names: input file list passed to ``inputMdstList``.
        prefix: identifier forwarded to DeepFlavorTagger (4th positional argument).
        environmentType: environment type passed to ``inputMdstList``.
        max_events: passed to ``process``; presumably 0 means "all events" (confirm).
    """
    main = create_path()

    inputMdstList(environmentType, file_names, path=main)

    fillParticleList('pi+:highPID', 'piid >= .1', path=main)
    fillParticleList('mu+:highPID', 'muid >= .1', path=main)

    # Reconstruct K_S0 -> pi+ pi-, keeping candidates with .25 <= M <= .75,
    # then fit the K_S0 vertex.
    reconstructDecay('K_S0:pipi -> pi+:highPID pi-:highPID', '.25 <= M <= .75', path=main)
    raveFit('K_S0:pipi', 0., path=main, silence_warning=True)

    # Reconstruct J/psi -> mu+ mu- and fit it with fit_type='massvertex'.
    reconstructDecay('J/psi:mumu -> mu+:highPID mu-:highPID', '3.0 <= M <= 3.2 ', path=main)
    # Tighter alternative: applyCuts('J/psi:mumu', '3.07 < M < 3.11', path=main)
    applyCuts('J/psi:mumu', '', path=main)
    raveFit('J/psi:mumu', 0., fit_type='massvertex', path=main, silence_warning=True)

    # Reconstruct B0 -> J/psi K_S0 and fit the B0 vertex; the decay string selects (^)
    # the two muons.
    reconstructDecay('B0:jpsiks -> J/psi:mumu K_S0:pipi', '5.2 <= M <= 5.4', path=main)
    raveFit('B0:jpsiks', 0., 'vertex', 'B0 -> [J/psi -> ^mu+ ^mu-] K_S0', '', path=main, silence_warning=True)

    # Perform MC matching (MC truth association). Always before TagV.
    matchMCTruth('B0:jpsiks', path=main)

    # Build the rest of the event associated to the B0 and keep true signal only.
    buildRestOfEvent('B0:jpsiks', path=main)
    applyCuts('B0:jpsiks', 'isSignal > 0.5', path=main)

    # Mode names are lower-case, as in the other DeepFlavorTagger calls in this file.
    DeepFlavorTagger('B0:jpsiks', 'expert', working_dir, prefix, transform_to_probability=True, path=main)
    # 'B0ProbabilityMC' uses the digit zero, like 'B0Probability' (was 'BOProbabilityMC').
    variablesToNtuple('B0:jpsiks', ['extraInfo(qrCombined)', 'extraInfo(qrMC)', 'extraInfo(B0Probability)',
                                    'extraInfo(B0ProbabilityMC)'],
                      filename=os.path.join(working_dir, 'test_output.root'), path=main)

    main.add_module('ProgressBar')

    process(main, max_events)
    print(statistics)
variablesToNtuple
Definition: variablesToNtuple.py:1
dft.DeepFlavorTagger
Definition: DeepFlavorTagger.py:1