Belle II Software development
B2A712-DeepContinuumSuppression_MVATrain.py
1#!/usr/bin/env python3
2
3
10
11
28
29import basf2 as b2
30import basf2_mva
31import subprocess
32import json
33import os
34
35
def _ranked_feature_names(particle_lists, n_ranks, feature_names):
    """
    Build flattened feature names of the form ``{feature}_{list}{rank}``.

    The ordering is identical to three nested loops: particle list (outermost),
    then rank, then feature name (innermost).

    :param particle_lists: Names of the particle lists, e.g. 'TPsig'.
    :param n_ranks: Number of ranked slots per list (ranks 0 .. n_ranks - 1).
    :param feature_names: Per-particle variable names.
    :return: List of feature names
    """
    return [
        f'{feature}_{plist}{rank}'
        for plist in particle_lists
        for rank in range(n_ranks)
        for feature in feature_names
    ]


def choose_input_features(use_vertex_features=True, use_charge_and_ROE_features=False, use_continuum_features=1):
    """
    Function to return all names of input features.

    Per-particle features are flattened as ``{variable}_{list}{rank}``:
    ranks 0-4 for each of the track lists (TPsig, TMsig, TProe, TMroe) and
    ranks 0-9 for each of the cluster lists (Csig, Croe). The old continuum
    features, when requested, are appended after all per-particle features.

    :param use_vertex_features: If vertex info should be included (track lists only).
    :param use_charge_and_ROE_features: If charge and ROE should be included as extra features (information already
        included in group structure). This option is only necessary when using Relation Layers.
    :param use_continuum_features: Use old Continuum Features (0: No, 1: Yes, 2: Use only the old features)
    :return: List of feature names
    """
    contVar = [
        'R2',
        'thrustBm',
        'thrustOm',
        'cosTBTO',
        'cosTBz',
        'KSFWVariables(et)',
        'KSFWVariables(mm2)',
        'KSFWVariables(hso00)',
        'KSFWVariables(hso02)',
        'KSFWVariables(hso04)',
        'KSFWVariables(hso10)',
        'KSFWVariables(hso12)',
        'KSFWVariables(hso14)',
        'KSFWVariables(hso20)',
        'KSFWVariables(hso22)',
        'KSFWVariables(hso24)',
        'KSFWVariables(hoo0)',
        'KSFWVariables(hoo1)',
        'KSFWVariables(hoo2)',
        'KSFWVariables(hoo3)',
        'KSFWVariables(hoo4)',
        'CleoConeCS(1)',
        'CleoConeCS(2)',
        'CleoConeCS(3)',
        'CleoConeCS(4)',
        'CleoConeCS(5)',
        'CleoConeCS(6)',
        'CleoConeCS(7)',
        'CleoConeCS(8)',
        'CleoConeCS(9)']

    if use_continuum_features == 2:
        return contVar

    basic_variables = ['p', 'phi', 'cosTheta', 'pErr', 'phiErr', 'cosThetaErr']
    vertex_variables = ['distance', 'dphi', 'dcosTheta']

    cluster_specific_variables = ['clusterNHits', 'clusterTiming', 'clusterE9E25', 'clusterReg']
    track_specific_variables = ['kaonID', 'electronID', 'muonID', 'protonID', 'pValue', 'nCDCHits']

    # Order matters: the optional charge/ROE entries come before the 'thrustsig' variables.
    if use_charge_and_ROE_features:
        cluster_specific_variables += ['isInRestOfEvent']
        track_specific_variables += ['isInRestOfEvent', 'charge']

    cluster_specific_variables += ['thrustsig' + var for var in basic_variables]
    track_specific_variables += ['thrustsig' + var for var in basic_variables]

    # Vertex variables only exist for tracks, not for clusters.
    if use_vertex_features:
        track_specific_variables += ['thrustsig' + var for var in vertex_variables]

    cluster_lists = ['Csig', 'Croe']
    track_lists = ['TPsig', 'TMsig', 'TProe', 'TMroe']

    variables = _ranked_feature_names(track_lists, 5, track_specific_variables)
    variables += _ranked_feature_names(cluster_lists, 10, cluster_specific_variables)

    if use_continuum_features:
        variables += contVar

    return variables
114
115
116if __name__ == "__main__":
117
118 if not os.getenv('BELLE2_EXAMPLES_DATA_DIR'):
119 b2.B2FATAL("You need the example data installed. Run `b2install-data example` in terminal for it.")
120
121 path = os.getenv('BELLE2_EXAMPLES_DATA_DIR')+'/mva/'
122
123 train_data = path + 'DNN_train.root'
124 test_data = path + 'DNN_test.root'
125
126 general_options = basf2_mva.GeneralOptions()
127 general_options.m_datafiles = basf2_mva.vector(train_data)
128 general_options.m_treename = "tree"
129 general_options.m_identifier = "Deep_Feed_Forward.xml"
130 general_options.m_variables = basf2_mva.vector(*choose_input_features(True, False, 1))
131 general_options.m_spectators = basf2_mva.vector('Mbc', 'DeltaZ')
132 general_options.m_target_variable = "isNotContinuumEvent"
133
134 specific_options = basf2_mva.PythonOptions()
135 specific_options.m_framework = "keras"
136 specific_options.m_steering_file = 'analysis/examples/mva/B2A714-DeepContinuumSuppression_MVAModel.py'
137 specific_options.m_training_fraction = 0.9
138
139 # These options are custom made in B2A714. You can also add your own parameters.
140 # Try different options and compare them by handing multiple weightfiles in basf2_mva_evaluation.py
141 keras_dic = {
142 # If Relation layer should be used instead of Feed Forward.
143 # Only works with choose_input_features(True, True, 1)
144 'use_relation_layers': False,
145 # The following options are for using Adversaries. To disable them leave lambda to zero.
146 # See mva/examples/keras/adversary_network.py for details
147 'lambda': 0, # Use 500 as starting point to try the Adversaries out
148 'number_bins': 10,
149 'adversary_steps': 5}
150 specific_options.m_config = json.dumps(keras_dic)
151
152 # Train a MVA method and store the weightfile (Deep_Feed_Forward.xml) locally.
153 basf2_mva.teacher(general_options, specific_options)
154
155 # Evaluate training.
156 subprocess.call('basf2_mva_evaluate.py '
157 ' -train ' + train_data +
158 ' -data ' + test_data +
159 ' -id ' + 'Deep_Feed_Forward.xml' +
160 ' --output qqbarSuppressionEvaluation.pdf' +
161 ' --fillnan',
162 shell=True
163 )
164
165 # If you're only interested in the network output distribution, then
166 # comment these in to apply the trained methods on train and test sample
167 #
168 # basf2_mva.expert(basf2_mva.vector('Deep_Feed_Forward.xml'),
169 # basf2_mva.vector(train_data), 'tree', 'MVAExpert_train.root')
170 # basf2_mva.expert(basf2_mva.vector('Deep_Feed_Forward.xml'),
171 # basf2_mva.vector(test_data), 'tree', 'MVAExpert_test.root')