36def choose_input_features(use_vertex_features=True, use_charge_and_ROE_features=False, use_continuum_features=1):
38 Function to return all names of input features.
39 :param use_vertex_features: If Vertex info should be included.
40 :param use_charge_and_ROE_features: If charge and ROE should be included as extra features(information already
41 included in group structure). This option is only nevessary when using Relation Layers.
42 :param use_continuum_features: Use old Continuum Features (0: No, 1: Yes, 2: Use only the old features)
43 :return: Array of feature names
51 'KSFWVariables(pt_sum)',
53 'KSFWVariables(hso00)',
54 'KSFWVariables(hso02)',
55 'KSFWVariables(hso04)',
56 'KSFWVariables(hso10)',
57 'KSFWVariables(hso12)',
58 'KSFWVariables(hso14)',
59 'KSFWVariables(hso20)',
60 'KSFWVariables(hso22)',
61 'KSFWVariables(hso24)',
62 'KSFWVariables(hoo0)',
63 'KSFWVariables(hoo1)',
64 'KSFWVariables(hoo2)',
65 'KSFWVariables(hoo3)',
66 'KSFWVariables(hoo4)',
77 if use_continuum_features == 2:
80 basic_variables = [
'p',
'phi',
'cosTheta',
'pErr',
'phiErr',
'cosThetaErr']
81 vertex_variables = [
'distance',
'dphi',
'dcosTheta']
83 cluster_specific_variables = [
'clusterNHits',
'clusterTiming',
'clusterE9E25',
'clusterReg']
84 track_specific_variables = [
'kaonID',
'electronID',
'muonID',
'protonID',
'pValue',
'nCDCHits']
86 if use_charge_and_ROE_features:
87 cluster_specific_variables += [
'isInRestOfEvent']
88 track_specific_variables += [
'isInRestOfEvent',
'charge']
90 cluster_specific_variables += [
'thrustsig' + var
for var
in basic_variables]
91 track_specific_variables += [
'thrustsig' + var
for var
in basic_variables]
93 if use_vertex_features:
94 track_specific_variables += [
'thrustsig' + var
for var
in vertex_variables]
96 cluster_lists = [
'Csig',
'Croe']
97 track_lists = [
'TPsig',
'TMsig',
'TProe',
'TMroe']
100 for plist
in track_lists:
101 for rank
in range(5):
102 for var
in track_specific_variables:
103 variables.append(f
'{var}_{plist}{rank}')
105 for plist
in cluster_lists:
106 for rank
in range(10):
107 for var
in cluster_specific_variables:
108 variables.append(f
'{var}_{plist}{rank}')
110 if use_continuum_features:
116if __name__ ==
"__main__":
118 if not os.getenv(
'BELLE2_EXAMPLES_DATA_DIR'):
119 b2.B2FATAL(
"You need the example data installed. Run `b2install-data example` in terminal for it.")
121 path = os.getenv(
'BELLE2_EXAMPLES_DATA_DIR')+
'/mva/'
123 train_data = path +
'DNN_train.root'
124 test_data = path +
'DNN_test.root'
126 general_options = basf2_mva.GeneralOptions()
127 general_options.m_datafiles = basf2_mva.vector(train_data)
128 general_options.m_treename =
"tree"
129 general_options.m_identifier =
"Deep_Feed_Forward.xml"
130 general_options.m_variables = basf2_mva.vector(*choose_input_features(
True,
False, 1))
131 general_options.m_spectators = basf2_mva.vector(
'Mbc',
'DeltaZ')
132 general_options.m_target_variable =
"isNotContinuumEvent"
134 specific_options = basf2_mva.PythonOptions()
135 specific_options.m_framework =
"keras"
136 specific_options.m_steering_file =
'analysis/examples/mva/B2A714-DeepContinuumSuppression_MVAModel.py'
137 specific_options.m_training_fraction = 0.9
144 'use_relation_layers':
False,
149 'adversary_steps': 5}
150 specific_options.m_config = json.dumps(keras_dic)
153 basf2_mva.teacher(general_options, specific_options)
156 subprocess.call(
'basf2_mva_evaluate.py '
157 ' -train ' + train_data +
158 ' -data ' + test_data +
159 ' -id ' +
'Deep_Feed_Forward.xml' +
160 ' --output qqbarSuppressionEvaluation.pdf' +