Belle II Software  release-08-01-10
B2A712-DeepContinuumSuppression_MVATrain.py
1 #!/usr/bin/env python3
2 
3 
10 
11 
28 
29 import basf2 as b2
30 import basf2_mva
31 import subprocess
32 import json
33 import os
34 
35 
def choose_input_features(use_vertex_features=True, use_charge_and_ROE_features=False, use_continuum_features=1):
    """
    Assemble the names of all input features used for the training.

    :param use_vertex_features: If True, the vertex variables ('distance', 'dphi', 'dcosTheta')
        are added to the per-track features.
    :param use_charge_and_ROE_features: If True, 'isInRestOfEvent' (tracks and clusters) and 'charge'
        (tracks only) are added as extra features. This information is already included in the group
        structure, so the option is only necessary when using Relation Layers.
    :param use_continuum_features: Use the old continuum features
        (0: no, 1: yes, 2: use only the old features).
    :return: List of feature names
    """
    # Classic event-level continuum suppression variables.
    ksfw_tags = ('et', 'mm2',
                 'hso00', 'hso02', 'hso04',
                 'hso10', 'hso12', 'hso14',
                 'hso20', 'hso22', 'hso24',
                 'hoo0', 'hoo1', 'hoo2', 'hoo3', 'hoo4')
    continuum_names = ['R2', 'thrustBm', 'thrustOm', 'cosTBTO', 'cosTBz']
    continuum_names.extend(f'KSFWVariables({tag})' for tag in ksfw_tags)
    continuum_names.extend(f'CleoConeCS({cone})' for cone in range(1, 10))

    # Mode 2 short-circuits: only the old continuum features are wanted.
    if use_continuum_features == 2:
        return continuum_names

    def expand(names, particle_lists, n_ranks):
        """Return '<name>_<list><rank>' for every list, rank and name (name varies fastest)."""
        return [f'{name}_{plist}{rank}'
                for plist in particle_lists
                for rank in range(n_ranks)
                for name in names]

    thrust_prefix = 'thrustsig'
    thrust_basic = [thrust_prefix + name
                    for name in ('p', 'phi', 'cosTheta', 'pErr', 'phiErr', 'cosThetaErr')]

    per_cluster = ['clusterNHits', 'clusterTiming', 'clusterE9E25', 'clusterReg']
    per_track = ['kaonID', 'electronID', 'muonID', 'protonID', 'pValue', 'nCDCHits']

    if use_charge_and_ROE_features:
        per_cluster.append('isInRestOfEvent')
        per_track.extend(['isInRestOfEvent', 'charge'])

    per_cluster.extend(thrust_basic)
    per_track.extend(thrust_basic)

    if use_vertex_features:
        per_track.extend(thrust_prefix + name for name in ('distance', 'dphi', 'dcosTheta'))

    # Tracks first (5 ranks per list), then clusters (10 ranks per list).
    features = expand(per_track, ('TPsig', 'TMsig', 'TProe', 'TMroe'), 5)
    features.extend(expand(per_cluster, ('Csig', 'Croe'), 10))

    # Any truthy mode other than 2 appends the continuum features at the end.
    if use_continuum_features:
        features.extend(continuum_names)

    return features
114 
115 
if __name__ == "__main__":
    # Example steering script: train a Keras-based deep continuum suppression
    # network on the Belle II example data and then evaluate the training.

    # Read the environment variable once; the example data must be installed.
    examples_data_dir = os.getenv('BELLE2_EXAMPLES_DATA_DIR')
    if not examples_data_dir:
        b2.B2FATAL("You need the example data installed. Run `b2install-data example` in terminal for it.")

    path = examples_data_dir + '/mva/'

    train_data = path + 'DNN_train.root'
    test_data = path + 'DNN_test.root'

    # Single definition of the weightfile name, shared by training and evaluation.
    weightfile = "Deep_Feed_Forward.xml"

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(train_data)
    general_options.m_treename = "tree"
    general_options.m_identifier = weightfile
    general_options.m_variables = basf2_mva.vector(*choose_input_features(True, False, 1))
    general_options.m_spectators = basf2_mva.vector('Mbc', 'DeltaZ')
    general_options.m_target_variable = "isNotContinuumEvent"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'analysis/examples/mva/B2A714-DeepContinuumSuppression_MVAModel.py'
    specific_options.m_training_fraction = 0.9

    # These options are custom made in B2A714. You can also add your own parameters.
    # Try different options and compare them by handing multiple weightfiles in basf2_mva_evaluation.py
    keras_dic = {
        # If Relation layer should be used instead of Feed Forward.
        # Only works with choose_input_features(True, True, 1)
        'use_relation_layers': False,
        # The following options are for using Adversaries. To disable them leave lambda to zero.
        # See mva/examples/keras/adversary_network.py for details
        'lambda': 0,  # Use 500 as starting point to try the Adversaries out
        'number_bins': 10,
        'adversary_steps': 5}
    specific_options.m_config = json.dumps(keras_dic)

    # Train a MVA method and store the weightfile (Deep_Feed_Forward.xml) locally.
    basf2_mva.teacher(general_options, specific_options)

    # Evaluate training.
    # The command is passed as an argument list without a shell, so data paths
    # containing spaces or shell metacharacters reach the tool unmodified.
    subprocess.call([
        'basf2_mva_evaluate.py',
        '-train', train_data,
        '-data', test_data,
        '-id', weightfile,
        '--output', 'qqbarSuppressionEvaluation.pdf',
        '--fillnan',
    ])

    # If you're only interested in the network output distribution, then
    # comment these in to apply the trained methods on train and test sample
    #
    # basf2_mva.expert(basf2_mva.vector('Deep_Feed_Forward.xml'),
    #                  basf2_mva.vector(train_data), 'tree', 'MVAExpert_train.root')
    # basf2_mva.expert(basf2_mva.vector('Deep_Feed_Forward.xml'),
    #                  basf2_mva.vector(test_data), 'tree', 'MVAExpert_test.root')