12import modularAnalysis
as ma
16from ROOT
import Belle2
19Check the MVAExpertModule and MVAMultipleExpertsModule modules properly sets the extraInfo fields they should
20 and that the overwrite settings are respected.
25 """Check if the extra Info values are correctly overwritten"""
28 """Create particle list object"""
38 """check the extra info names are what we expect!"""
39 for multiclass
in [
False,
True]:
40 for multiexpert_prefix
in [
'multi_',
'']:
41 for name, value
in [(
'low_never', 0.5),
50 name = multiexpert_prefix+name
52 for index
in range(3):
53 compare_val = index + value
54 extra_info_name = f
'multiclass_{name}_{index}'
57 ei_value = p.getExtraInfo(extra_info_name)
58 assert ei_value == compare_val,\
59 f
'ExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {compare_val}'
62 if 'multi_' in extra_info_name:
64 ei_value = p.getDaughter(0).getExtraInfo(extra_info_name)
65 assert ei_value == compare_val,\
66 f
'ExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {compare_val}'
69 assert ei_value == compare_val,\
70 f
'eventExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {compare_val}'
72 extra_info_name = name
75 ei_value = p.getExtraInfo(name)
76 assert ei_value == value,\
77 f
'ExtraInfo "{name}" value "{ei_value}" not what was expected "{value}"'
80 if 'multi_' in extra_info_name:
82 ei_value = p.getDaughter(0).getExtraInfo(name)
83 assert ei_value == value,\
84 f
'ExtraInfo "{name}" value "{ei_value}" not what was expected "{value}"'
86 assert ei_value == value,\
87 f
'eventExtraInfo "{name}" value "{ei_value}" not what was expected "{value}"'
91def create_weightfile(name, val):
93 <?xml version="1.0" encoding=
"utf-8"?>
94 <method>Trivial</method>
95 <weightfile>{name}</weightfile>
96 <treename>tree</treename>
97 <target_variable>isSignal</target_variable>
98 <weight_variable>__weight__</weight_variable>
99 <signal_class>1</signal_class>
100 <max_events>0</max_events>
101 <number_feature_variables>1</number_feature_variables>
102 <variable0>beamE</variable0>
103 <number_spectator_variables>0</number_spectator_variables>
104 <number_data_files>1</number_data_files>
105 <datafile0>train.root</datafile0>
106 <Trivial_version>1</Trivial_version>
107 <Trivial_output>{val}</Trivial_output>
108 <signal_fraction>0.5</signal_fraction>
112def create_multiclass_weightfile(name, offset_val):
114 <?xml version="1.0" encoding=
"utf-8"?>
115 <method>Trivial</method>
116 <weightfile>{name}</weightfile>
117 <treename>tree</treename>
118 <target_variable>isSignal</target_variable>
119 <weight_variable>__weight__</weight_variable>
120 <signal_class>1</signal_class>
121 <max_events>0</max_events>
122 <nClasses>3</nClasses>
123 <number_feature_variables>1</number_feature_variables>
124 <variable0>beamE</variable0>
125 <number_spectator_variables>0</number_spectator_variables>
126 <number_data_files>1</number_data_files>
127 <datafile0>train.root</datafile0>
128 <Trivial_version>1</Trivial_version>
129 <Trivial_output>0</Trivial_output>
130 <Trivial_number_of_multiple_outputs>3</Trivial_number_of_multiple_outputs>
131 <Trivial_multiple_output0>{offset_val + 0}</Trivial_multiple_output0>
132 <Trivial_multiple_output1>{offset_val + 1}</Trivial_multiple_output1>
133 <Trivial_multiple_output2>{offset_val + 2}</Trivial_multiple_output2>
134 <signal_fraction>0.5</signal_fraction>
138if __name__ == "__main__":
140 path = basf2.create_path()
143 ma.fillParticleList(
'pi+:test',
'', path=path)
144 ma.fillParticleList(
'Lambda0:test -> p+ pi-',
'', path=path)
147 for prefix
in [
'',
'multiclass_']:
148 for identifier, extra_info_name, overwrite_option
in [
150 (
'weightfile_mid.xml',
'low_never', 0),
151 (
'weightfile_mid.xml',
'low_always', 2),
152 (
'weightfile_mid.xml',
'low_lower', -1),
153 (
'weightfile_mid.xml',
'low_higher', 1),
154 (
'weightfile_mid.xml',
'high_never', 0),
155 (
'weightfile_mid.xml',
'high_always', 2),
156 (
'weightfile_mid.xml',
'high_lower', -1),
157 (
'weightfile_mid.xml',
'high_higher', 1),
159 (
'weightfile_low.xml',
'low_never', 0),
160 (
'weightfile_low.xml',
'low_always', 2),
161 (
'weightfile_low.xml',
'low_lower', -1),
162 (
'weightfile_low.xml',
'low_higher', 1),
163 (
'weightfile_high.xml',
'high_never', 0),
164 (
'weightfile_high.xml',
'high_always', 2),
165 (
'weightfile_high.xml',
'high_lower', -1),
166 (
'weightfile_high.xml',
'high_higher', 1),
168 extra_info_name = prefix+extra_info_name
169 identifier = prefix+identifier
172 listNames=[
'pi+:test'],
173 extraInfoName=extra_info_name,
174 identifier=identifier,
175 overwriteExistingExtraInfo=overwrite_option)
178 listNames=[
'Lambda0:test -> ^p+ pi-'],
179 extraInfoName=extra_info_name,
180 identifier=identifier,
181 overwriteExistingExtraInfo=overwrite_option)
185 extraInfoName=extra_info_name,
186 identifier=identifier,
187 overwriteExistingExtraInfo=overwrite_option)
189 for identifiers, extra_info_names, overwrite_options
in [
190 ([
'weightfile_mid.xml',
'weightfile_mid.xml'], [
'multi_low_never',
'multi_high_never'], [0, 0]),
191 ([
'weightfile_mid.xml',
'weightfile_mid.xml'], [
'multi_low_always',
'multi_high_always'], [2, 2]),
192 ([
'weightfile_mid.xml',
'weightfile_mid.xml'], [
'multi_low_lower',
'multi_high_lower'], [-1, -1]),
193 ([
'weightfile_mid.xml',
'weightfile_mid.xml'], [
'multi_low_higher',
'multi_high_higher'], [1, 1]),
194 ([
'weightfile_low.xml',
'weightfile_high.xml'], [
'multi_low_never',
'multi_high_never'], [0, 0]),
195 ([
'weightfile_low.xml',
'weightfile_high.xml'], [
'multi_low_always',
'multi_high_always'], [2, 2]),
196 ([
'weightfile_low.xml',
'weightfile_high.xml'], [
'multi_low_higher',
'multi_high_lower'], [1, -1]),
197 ([
'weightfile_low.xml',
'weightfile_high.xml'], [
'multi_low_lower',
'multi_high_higher'], [-1, 1]),
199 extra_info_names = [prefix+x
for x
in extra_info_names]
200 identifiers = [prefix+x
for x
in identifiers]
202 'MVAMultipleExperts',
203 listNames=[
'pi+:test'],
204 extraInfoNames=extra_info_names,
205 identifiers=identifiers,
206 overwriteExistingExtraInfo=overwrite_options)
207 path.add_module(
'MVAMultipleExperts', listNames=[], extraInfoNames=extra_info_names,
208 identifiers=identifiers, overwriteExistingExtraInfo=overwrite_options)
212 with tempfile.TemporaryDirectory()
as tempdir:
213 for name, val
in [(
'weightfile_low.xml', 0.0),
214 (
'weightfile_mid.xml', 0.5),
215 (
'weightfile_high.xml', 1.0)]:
216 with open(name,
"w")
as f:
217 f.write(create_weightfile(name, val))
219 for name, val
in [(
'multiclass_weightfile_low.xml', 0.0),
220 (
'multiclass_weightfile_mid.xml', 0.5),
221 (
'multiclass_weightfile_high.xml', 1.0)]:
222 with open(name,
"w")
as f:
223 f.write(create_multiclass_weightfile(name, val))
225 basf2.process(path, 10)
a (simplified) python wrapper for StoreObjPtr.
def require_file(filename, data_type="", py_case=None)