Belle II Software  release-05-01-25
trainCurlTaggerClassifier.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 # Author: Marcel Hohmann (marcel.hohmann@desy.de)
4 
import os
import shutil
import subprocess
import sys

import basf2.core
from modularAnalysis import inputMdst, tagCurlTracks, process, statistics, register_module
from stdCharged import stdPi
10 
# Input mdst file(s): the first command-line argument if one is given,
# otherwise the default MC file pattern below.
try:
    input_file_name = sys.argv[1]
except IndexError:
    # Only a missing argument selects the default. Catching BaseException here
    # would also swallow KeyboardInterrupt / SystemExit.
    input_file_name = (
        '/hsm/belle2/bdata/MC/release-02-00-01/DB00000411/MC11/prod00005678/'
        's00/e0000/4S/r00000/mixed/mdst/sub00/mdst_00000*_prod00005678_task0000000*.root'
    )

upload = False  # upload to conditions database
remove_local_files = False  # delete local db and training data
tag_name = 'development'  # tag handed to `conditionsdb upload` when `upload` is set

# names used by the CurlTagger module:
training_file_name = 'CurlTagger_TrainingData_BelleII.root'
identifier = 'CurlTagger_FastBDT_BelleII'
24 
# --- create training data set and train the classifier ---
# Everything below is appended to one basf2 path, which is then processed once.
training_path = basf2.core.Path()

# Read the input mdst file(s) and build the 'pi+:all' particle list.
inputMdst('default', input_file_name, path=training_path)
stdPi('all', path=training_path)

# Run the curl tagger on 'pi+:all' in training mode with the 'mva' selector.
tagCurlTracks(
    'pi+:all',
    train=True,
    selectorType='mva',
    path=training_path,
)

# Register a ProgressBar module and append it to the path.
progress = register_module('ProgressBar')
training_path.add_module(progress)

# Process the path; the second argument is 2e5 as an integer.
process(training_path, 200000)
print(statistics)
37 
38 
# Locations of the local database and its database.txt, relative to the
# current working directory.
here = os.getcwd()
local_db_dir = os.path.join(here, 'localdb')
data_base_file = os.path.join(local_db_dir, 'database.txt')

# upload to global database
if upload:
    # The command is passed as an argument list, not a shell string, so a
    # working directory containing spaces or shell metacharacters cannot split
    # or alter it. check=True raises CalledProcessError on a non-zero exit
    # status, so a failed upload stops the script before the cleanup below
    # can delete the local database.
    subprocess.run(['conditionsdb', 'upload', tag_name, data_base_file], check=True)

if remove_local_files:
    # Best-effort cleanup, as with the previous `rm` calls: a target that is
    # missing or cannot be removed is reported and the script carries on.
    cleanup_steps = (
        (shutil.rmtree, local_db_dir),
        (os.remove, os.path.join(here, training_file_name)),
    )
    for remove, target in cleanup_steps:
        try:
            remove(target)
        except OSError as err:
            print('could not remove {}: {}'.format(target, err), file=sys.stderr)
basf2.core
Definition: core.py:1