Belle II Software  release-06-02-00
trainKshortClassifier.py
1 #!/usr/bin/env python3
2 
3 
10 
import os
import shutil
import subprocess
import sys

import basf2
from basf2 import process, statistics
from modularAnalysis import inputMdst, matchMCTruth, variablesToNtuple
from stdV0s import stdKshorts
17 
# Optional command-line arguments:
#   argv[1] -- mDST input file name (wildcards allowed, as in the default below)
#   argv[2] -- MVA identifier; by default train to xml then upload to localdb
# Only a missing argument selects the default; no other error (e.g. Ctrl-C)
# is masked by the argument lookup.
if len(sys.argv) > 1:
    input_file_name = sys.argv[1]
else:
    input_file_name = ('/hsm/belle2/bdata/MC/release-01-00-03/DB00000294/MC10/prod00004770/'
                       's00/e0000/4S/r00000/mixed/mdst/sub00/mdst_00000*_prod00004770_task0000000*.root')

if len(sys.argv) > 2:
    identifier = sys.argv[2]
else:
    identifier = 'Kshort_FastBDT.xml'  # by default train to xml then upload to localdb
28 
29 
# Truth label the classifier is trained to predict; it relies on the MC
# truth matching that is run on the K_S0 list before the ntuple is written.
target_variable = 'isSignal'

# Classifier input features, evaluated for every K_S0 candidate.
my_variables = [
    'SigM',                                                  # basf2 SigM mass variable
    'formula( E / E_uncertainty )',                          # energy over its uncertainty
    'formula( flightTime / flightTimeErr)',                  # flight-time significance
    'cosAngleBetweenMomentumAndVertexVector',                # momentum / vertex pointing
    'min(abs(daughter(0, d0)),abs(daughter(1, d0)))',        # smaller |d0| of the two daughters
    'formula(daughter(0, pionID) + daughter(1, pionID))',    # summed daughter pion ID
]

# Location of the training ntuple: output ROOT file and the tree inside it.
training_file_name = 'KshortClassifierTrainingData.root'
tree_name = 'ks_training_variables'
42 
43 
# --- create training data set ---
# Reconstruct K_S0 candidates from the input mDST, truth-match them and write
# the classifier features together with the truth label into the ntuple
# (tree `tree_name` inside `training_file_name`).
ks_list_name = 'K_S0:merged'
ntuple_columns = my_variables + [target_variable]

training_path = basf2.core.Path()
inputMdst('default', input_file_name, path=training_path)
stdKshorts(path=training_path)  # standard K_S0 reconstruction; its K_S0:merged list is used below
matchMCTruth(ks_list_name, path=training_path)
variablesToNtuple(ks_list_name, ntuple_columns, tree_name, training_file_name, path=training_path)

# Process at most 200k events, then report the per-module statistics.
process(training_path, 200_000)
print(statistics)
59 
60 
# --- train variables ---
# Train a FastBDT on the ntuple written above with basf2_mva_teacher.
# The command is passed as an argument list (no shell): each formula-style
# variable name, with its spaces, parentheses and commas, reaches the teacher
# as exactly one argument without any quoting. check=True aborts the script
# if training fails, instead of continuing to upload a weightfile that was
# never written.
training_command = [
    'basf2_mva_teacher',
    '--datafiles', training_file_name,
    '--treename', tree_name,
    '--identifier', identifier,
    '--variables', *my_variables,
    '--target_variable', target_variable,
    '--method', 'FastBDT',
    '--nTrees', '400',
    '--nCutLevels', '8',
    '--nLevels', '4',
]

subprocess.run(training_command, check=True)
72 
# Validity range of the payload written to the local database.
ex_b = 0  # experiment begin, 0 for all of them
ex_e = -1  # experiment end, -1 for all of them
run_b = 0  # run begin, 0 for all
run_e = -1  # run end, -1 for all of them
tag_name = "development"  # global tag name

upload = False  # upload to conditions database
remove_local_files = False  # delete local db and training data

# upload to local database from xml file
# The database identifier is the file identifier with a trailing '.xml'
# removed. Only a real suffix is stripped; split('.xml')[0] would also cut
# the name at an '.xml' occurring earlier in it.
if identifier.endswith('.xml'):
    identifier_db = identifier[:-len('.xml')]
else:
    identifier_db = identifier

# check=True stops the script if the upload fails, so a failed local upload
# is neither pushed to the global tag nor followed by deleting the inputs.
subprocess.run(
    [
        'basf2_mva_upload',
        '--identifier', identifier,
        '--db_identifier', identifier_db,
        '--begin_experiment', str(ex_b),
        '--end_experiment', str(ex_e),
        '--begin_run', str(run_b),
        '--end_run', str(run_e),
    ],
    check=True,
)
91 
here = os.getcwd()
local_db_dir = os.path.join(here, 'localdb')
data_base_file = os.path.join(local_db_dir, 'database.txt')

# upload to global database
# check=True aborts on failure, so the local database is never deleted below
# after an unsuccessful global upload.
if upload:
    subprocess.run(['b2conditionsdb-upload', tag_name, data_base_file], check=True)

# Best-effort cleanup of the local database, the training ntuple and the
# weightfile. Entries that do not exist are skipped rather than raising.
if remove_local_files:
    shutil.rmtree(local_db_dir, ignore_errors=True)
    for leftover in (training_file_name, identifier):
        leftover_path = os.path.join(here, leftover)
        if os.path.isfile(leftover_path):
            os.remove(leftover_path)