13 import modularAnalysis
as ma
17 from ROOT
import Belle2
20 Check the MVAExpertModule and MVAMultipleExpertsModule modules properly sets the extraInfo fields they should
21 and that the overwrite settings are respected.
26 """Check if the extra Info values are correctly overwritten"""
29 """Create particle list object"""
37 """check the extra info names are what we expect!"""
38 for multiclass
in [
False,
True]:
39 for multiexpert_prefix
in [
'multi_',
'']:
40 for name, value
in [(
'low_never', 0.5),
49 name = multiexpert_prefix+name
51 for index
in range(3):
52 compare_val = index + value
53 extra_info_name = f
'multiclass_{name}_{index}'
54 for p
in self.
plistplist:
55 ei_value = p.getExtraInfo(extra_info_name)
56 assert ei_value == compare_val,\
57 f
'ExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {compare_val}'
59 ei_value = self.
eventExtraInfoeventExtraInfo.getExtraInfo(extra_info_name)
60 assert ei_value == compare_val,\
61 f
'eventExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {compare_val}'
63 extra_info_name = name
64 for p
in self.
plistplist:
65 ei_value = p.getExtraInfo(name)
66 assert ei_value == value,\
67 f
'ExtraInfo "{name}" value "{ei_value}" not what was expected "{value}"'
69 assert ei_value == value,\
70 f
'eventExtraInfo "{name}" value "{ei_value}" not what was expected "{value}"'
74 def create_weightfile(name, val):
76 <?xml version="1.0" encoding="utf-8"?>
77 <method>Trivial</method>
78 <weightfile>{name}</weightfile>
79 <treename>tree</treename>
80 <target_variable>isSignal</target_variable>
81 <weight_variable>__weight__</weight_variable>
82 <signal_class>1</signal_class>
83 <max_events>0</max_events>
84 <number_feature_variables>1</number_feature_variables>
85 <variable0>beamE</variable0>
86 <number_spectator_variables>0</number_spectator_variables>
87 <number_data_files>1</number_data_files>
88 <datafile0>train.root</datafile0>
89 <Trivial_version>1</Trivial_version>
90 <Trivial_output>{val}</Trivial_output>
91 <signal_fraction>0.5</signal_fraction>
95 def create_multiclass_weightfile(name, offset_val):
97 <?xml version="1.0" encoding="utf-8"?>
98 <method>Trivial</method>
99 <weightfile>{name}</weightfile>
100 <treename>tree</treename>
101 <target_variable>isSignal</target_variable>
102 <weight_variable>__weight__</weight_variable>
103 <signal_class>1</signal_class>
104 <max_events>0</max_events>
105 <nClasses>3</nClasses>
106 <number_feature_variables>1</number_feature_variables>
107 <variable0>beamE</variable0>
108 <number_spectator_variables>0</number_spectator_variables>
109 <number_data_files>1</number_data_files>
110 <datafile0>train.root</datafile0>
111 <Trivial_version>1</Trivial_version>
112 <Trivial_output>0</Trivial_output>
113 <Trivial_number_of_multiple_outputs>3</Trivial_number_of_multiple_outputs>
114 <Trivial_multiple_output0>{offset_val + 0}</Trivial_multiple_output0>
115 <Trivial_multiple_output1>{offset_val + 1}</Trivial_multiple_output1>
116 <Trivial_multiple_output2>{offset_val + 2}</Trivial_multiple_output2>
117 <signal_fraction>0.5</signal_fraction>
121 if __name__ ==
"__main__":
123 path = basf2.create_path()
126 ma.fillParticleList(
'pi+:test',
'', path=path)
129 for prefix
in [
'',
'multiclass_']:
130 for identifier, extra_info_name, overwrite_option
in [
132 (
'weightfile_mid.xml',
'low_never', 0),
133 (
'weightfile_mid.xml',
'low_always', 2),
134 (
'weightfile_mid.xml',
'low_lower', -1),
135 (
'weightfile_mid.xml',
'low_higher', 1),
136 (
'weightfile_mid.xml',
'high_never', 0),
137 (
'weightfile_mid.xml',
'high_always', 2),
138 (
'weightfile_mid.xml',
'high_lower', -1),
139 (
'weightfile_mid.xml',
'high_higher', 1),
141 (
'weightfile_low.xml',
'low_never', 0),
142 (
'weightfile_low.xml',
'low_always', 2),
143 (
'weightfile_low.xml',
'low_lower', -1),
144 (
'weightfile_low.xml',
'low_higher', 1),
145 (
'weightfile_high.xml',
'high_never', 0),
146 (
'weightfile_high.xml',
'high_always', 2),
147 (
'weightfile_high.xml',
'high_lower', -1),
148 (
'weightfile_high.xml',
'high_higher', 1),
150 extra_info_name = prefix+extra_info_name
151 identifier = prefix+identifier
154 listNames=[
'pi+:test'],
155 extraInfoName=extra_info_name,
156 identifier=identifier,
157 overwriteExistingExtraInfo=overwrite_option)
161 extraInfoName=extra_info_name,
162 identifier=identifier,
163 overwriteExistingExtraInfo=overwrite_option)
165 for identifiers, extra_info_names, overwrite_options
in [
166 ([
'weightfile_mid.xml',
'weightfile_mid.xml'], [
'multi_low_never',
'multi_high_never'], [0, 0]),
167 ([
'weightfile_mid.xml',
'weightfile_mid.xml'], [
'multi_low_always',
'multi_high_always'], [2, 2]),
168 ([
'weightfile_mid.xml',
'weightfile_mid.xml'], [
'multi_low_lower',
'multi_high_lower'], [-1, -1]),
169 ([
'weightfile_mid.xml',
'weightfile_mid.xml'], [
'multi_low_higher',
'multi_high_higher'], [1, 1]),
170 ([
'weightfile_low.xml',
'weightfile_high.xml'], [
'multi_low_never',
'multi_high_never'], [0, 0]),
171 ([
'weightfile_low.xml',
'weightfile_high.xml'], [
'multi_low_always',
'multi_high_always'], [2, 2]),
172 ([
'weightfile_low.xml',
'weightfile_high.xml'], [
'multi_low_higher',
'multi_high_lower'], [1, -1]),
173 ([
'weightfile_low.xml',
'weightfile_high.xml'], [
'multi_low_lower',
'multi_high_higher'], [-1, 1]),
175 extra_info_names = [prefix+x
for x
in extra_info_names]
176 identifiers = [prefix+x
for x
in identifiers]
178 'MVAMultipleExperts',
179 listNames=[
'pi+:test'],
180 extraInfoNames=extra_info_names,
181 identifiers=identifiers,
182 overwriteExistingExtraInfo=overwrite_options)
183 path.add_module(
'MVAMultipleExperts', listNames=[], extraInfoNames=extra_info_names,
184 identifiers=identifiers, overwriteExistingExtraInfo=overwrite_options)
188 with tempfile.TemporaryDirectory()
as tempdir:
189 for name, val
in [(
'weightfile_low.xml', 0.0),
190 (
'weightfile_mid.xml', 0.5),
191 (
'weightfile_high.xml', 1.0)]:
192 with open(name,
"w")
as f:
193 f.write(create_weightfile(name, val))
195 for name, val
in [(
'multiclass_weightfile_low.xml', 0.0),
196 (
'multiclass_weightfile_mid.xml', 0.5),
197 (
'multiclass_weightfile_high.xml', 1.0)]:
198 with open(name,
"w")
as f:
199 f.write(create_multiclass_weightfile(name, val))
201 basf2.process(path, 10)
a (simplified) python wrapper for StoreObjPtr.
def require_file(filename, data_type="", py_case=None)