# NOTE(review): this copy of choose_input_features is damaged and cannot run as-is.
#   - Every statement carries a leading number that looks like its original line
#     number (36, 38, ..., 110), which is a SyntaxError in front of `def`.
#   - Several statements are split across lines, e.g. the f-string prefix `f`
#     is separated from its literal in the two `variables.append(...)` calls.
#   - Gaps in that numbering show missing lines: 37 and 44 (presumably the
#     docstring quotes), 45-52 and 67-76 (start/end of the list holding the
#     'KSFWVariables(...)' names, presumably the "old Continuum Features"),
#     78-79 (body of the `== 2` branch), 98-99 (presumably the initialisation
#     of `variables`, which is never assigned in this copy) and 111-115.
#   TODO: restore this function from its upstream source rather than patching it.
36 def choose_input_features(use_vertex_features=True, use_charge_and_ROE_features=False, use_continuum_features=1):
38 Function to return all names of input features.
39 :param use_vertex_features: If Vertex info should be included.
40 :param use_charge_and_ROE_features: If charge and ROE should be included as extra features(information already
41 included in group structure). This option is only nevessary when using Relation Layers.
42 :param use_continuum_features: Use old Continuum Features (0: No, 1: Yes, 2: Use only the old features)
43 :return: Array of feature names
53 'KSFWVariables(hso00)',
54 'KSFWVariables(hso02)',
55 'KSFWVariables(hso04)',
56 'KSFWVariables(hso10)',
57 'KSFWVariables(hso12)',
58 'KSFWVariables(hso14)',
59 'KSFWVariables(hso20)',
60 'KSFWVariables(hso22)',
61 'KSFWVariables(hso24)',
62 'KSFWVariables(hoo0)',
63 'KSFWVariables(hoo1)',
64 'KSFWVariables(hoo2)',
65 'KSFWVariables(hoo3)',
66 'KSFWVariables(hoo4)',
# Per the docstring, use_continuum_features == 2 means "use only the old
# features"; the body of this branch is missing from this copy.
77 if use_continuum_features == 2:
# Kinematic variables p, phi, cosTheta and their *Err counterparts; they are
# added to both the track and the cluster variable lists further down.
80 basic_variables = [
'p',
'phi',
'cosTheta',
'pErr',
'phiErr',
'cosThetaErr']
# Vertex-related variables; appended to the track variables only, and only
# when use_vertex_features is set.
81 vertex_variables = [
'distance',
'dphi',
'dcosTheta']
# Variables used only for candidates in the cluster lists.
83 cluster_specific_variables = [
'clusterNHits',
'clusterTiming',
'clusterE9E25',
'clusterReg']
# Variables used only for candidates in the track lists.
84 track_specific_variables = [
'kaonID',
'electronID',
'muonID',
'protonID',
'pValue',
'nCDCHits']
# Optional extras: 'isInRestOfEvent' for both candidate types, plus 'charge'
# for tracks. Per the docstring this information is already carried by the
# group structure and is only needed when using Relation Layers.
86 if use_charge_and_ROE_features:
87 cluster_specific_variables += [
'isInRestOfEvent']
88 track_specific_variables += [
'isInRestOfEvent',
'charge']
# Add the kinematic variables to both candidate types. The 'thrustsig' prefix
# is joined with no separator, e.g. 'thrustsigp', 'thrustsigcosTheta'.
90 cluster_specific_variables += [
'thrustsig' + var
for var
in basic_variables]
91 track_specific_variables += [
'thrustsig' + var
for var
in basic_variables]
# Vertex variables get the same 'thrustsig' prefix and apply to tracks only.
93 if use_vertex_features:
94 track_specific_variables += [
'thrustsig' + var
for var
in vertex_variables]
# Particle-list name stems; each is combined with a rank index below.
96 cluster_lists = [
'Csig',
'Croe']
97 track_lists = [
'TPsig',
'TMsig',
'TProe',
'TMroe']
# Expand into flat names of the form '{var}_{plist}{rank}', e.g. 'kaonID_TPsig0'.
# Track lists come first with ranks 0-4, then cluster lists with ranks 0-9. The
# nesting (list, rank, variable) fixes the order in which names are appended
# to `variables`, which is not initialised anywhere in this copy.
100 for plist
in track_lists:
101 for rank
in range(5):
102 for var
in track_specific_variables:
103 variables.append(f
'{var}_{plist}{rank}')
# Same expansion for the cluster lists, with 10 ranks per list.
105 for plist
in cluster_lists:
106 for rank
in range(10):
107 for var
in cluster_specific_variables:
108 variables.append(f
'{var}_{plist}{rank}')
# Per the docstring, 1 (the default) also includes the old continuum features;
# the body of this branch is missing from this copy.
110 if use_continuum_features:
116 if __name__ ==
"__main__":
118 if not os.getenv(
'BELLE2_EXAMPLES_DATA_DIR'):
119 b2.B2FATAL(
"You need the example data installed. Run `b2install-data example` in terminal for it.")
121 path = os.getenv(
'BELLE2_EXAMPLES_DATA_DIR')+
'/mva/'
123 train_data = path +
'DNN_train.root'
124 test_data = path +
'DNN_test.root'
126 general_options = basf2_mva.GeneralOptions()
127 general_options.m_datafiles = basf2_mva.vector(train_data)
128 general_options.m_treename =
"tree"
129 general_options.m_identifier =
"Deep_Feed_Forward.xml"
130 general_options.m_variables = basf2_mva.vector(*choose_input_features(
True,
False, 1))
131 general_options.m_spectators = basf2_mva.vector(
'Mbc',
'DeltaZ')
132 general_options.m_target_variable =
"isNotContinuumEvent"
134 specific_options = basf2_mva.PythonOptions()
135 specific_options.m_framework =
"keras"
136 specific_options.m_steering_file =
'analysis/examples/mva/B2A714-DeepContinuumSuppression_MVAModel.py'
137 specific_options.m_training_fraction = 0.9
144 'use_relation_layers':
False,
149 'adversary_steps': 5}
150 specific_options.m_config = json.dumps(keras_dic)
153 basf2_mva.teacher(general_options, specific_options)
156 subprocess.call(
'basf2_mva_evaluate.py '
157 ' -train ' + train_data +
158 ' -data ' + test_data +
159 ' -id ' +
'Deep_Feed_Forward.xml' +
160 ' --output qqbarSuppressionEvaluation.pdf' +