Belle II Software light-2406-ragdoll
test_MVAExpertModule.py
1#!/usr/bin/env python3
2
3
10
11import basf2
12import modularAnalysis as ma
13import tempfile
14import b2test_utils
15
16from ROOT import Belle2
17
"""
Check that the MVAExpertModule and MVAMultipleExpertsModule modules properly set the extraInfo fields they should,
and that the overwrite settings are respected.
"""
22
23
class MVAExtraInfoChecker(basf2.Module):
    """Verify that the extraInfo fields written by the MVA expert modules hold the expected values."""

    def initialize(self):
        """Bind the store objects that are inspected in every event."""
        # Final-state particle list ("FSP use case").
        self.plist = Belle2.PyStoreObj("pi+:test")
        # Composite particle list; only its first daughter is inspected ("decay string use case").
        self.comp_plist = Belle2.PyStoreObj("Lambda0:test")
        # Event-level extra info.
        self.eventExtraInfo = Belle2.PyStoreObj("EventExtraInfo")

    def event(self):
        """Assert that every extraInfo field carries the value expected for its name.

        Each base name is checked for the binary and the multiclass case, with and
        without the ``multi_`` prefix, on the particles of ``pi+:test``, on the first
        daughter of the particles of ``Lambda0:test`` and on the event extra info.

        Raises:
            AssertionError: if any stored value differs from the expected one.
        """
        # (base name, value expected in the binary case).  In the multiclass case
        # the expectation for class ``index`` is ``index + value``.
        expectations = (
            ('low_never', 0.5),
            ('low_always', 0.0),
            ('low_higher', 0.5),
            ('low_lower', 0.0),
            ('high_never', 0.5),
            ('high_always', 1.0),
            ('high_higher', 1.0),
            ('high_lower', 0.5),
        )

        for multiclass in (False, True):
            for multiexpert_prefix in ('multi_', ''):
                for base_name, value in expectations:
                    name = multiexpert_prefix + base_name

                    # Collect (extraInfo name, expected value, expected value as shown in the message).
                    if multiclass:
                        checks = [
                            (f'multiclass_{name}_{index}', index + value, f'{index + value}')
                            for index in range(3)
                        ]
                    else:
                        checks = [(name, value, f'"{value}"')]

                    for extra_info_name, expected, shown in checks:
                        # FSP use case
                        for particle in self.plist:
                            ei_value = particle.getExtraInfo(extra_info_name)
                            assert ei_value == expected, \
                                f'ExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {shown}'

                        # Decay string use case
                        for particle in self.comp_plist:
                            # Names carrying the 'multi_' prefix are not checked on the daughters.
                            # 'multiclass_' does not contain 'multi_', so it is unaffected.
                            if 'multi_' in extra_info_name:
                                continue
                            ei_value = particle.getDaughter(0).getExtraInfo(extra_info_name)
                            assert ei_value == expected, \
                                f'ExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {shown}'

                        # Event-level use case
                        ei_value = self.eventExtraInfo.getExtraInfo(extra_info_name)
                        assert ei_value == expected, \
                            f'eventExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {shown}'
88
89
90# prepare the weightfiles
def create_weightfile(name, val):
    """Return the XML text of a ``Trivial`` MVA weightfile with a constant output.

    Args:
        name: value written into the ``<weightfile>`` element.
        val: value written into the ``<Trivial_output>`` element.

    Returns:
        The weightfile content as a string: a leading newline followed by one
        indented XML element per line.
    """
    indent = "    "
    xml_lines = (
        '<?xml version="1.0" encoding="utf-8"?>',
        '<method>Trivial</method>',
        f'<weightfile>{name}</weightfile>',
        '<treename>tree</treename>',
        '<target_variable>isSignal</target_variable>',
        '<weight_variable>__weight__</weight_variable>',
        '<signal_class>1</signal_class>',
        '<max_events>0</max_events>',
        '<number_feature_variables>1</number_feature_variables>',
        '<variable0>beamE</variable0>',
        '<number_spectator_variables>0</number_spectator_variables>',
        '<number_data_files>1</number_data_files>',
        '<datafile0>train.root</datafile0>',
        '<Trivial_version>1</Trivial_version>',
        f'<Trivial_output>{val}</Trivial_output>',
        '<signal_fraction>0.5</signal_fraction>',
    )
    body = "".join(f"{indent}{line}\n" for line in xml_lines)
    return "\n" + body + indent
110
111
def create_multiclass_weightfile(name, offset_val):
    """Return the XML text of a three-class ``Trivial`` MVA weightfile.

    Args:
        name: value written into the ``<weightfile>`` element.
        offset_val: offset of the per-class outputs; class ``i`` gets ``offset_val + i``.

    Returns:
        The weightfile content as a string: a leading newline followed by one
        indented XML element per line.
    """
    indent = "    "
    n_outputs = 3

    header = [
        '<?xml version="1.0" encoding="utf-8"?>',
        '<method>Trivial</method>',
        f'<weightfile>{name}</weightfile>',
        '<treename>tree</treename>',
        '<target_variable>isSignal</target_variable>',
        '<weight_variable>__weight__</weight_variable>',
        '<signal_class>1</signal_class>',
        '<max_events>0</max_events>',
        '<nClasses>3</nClasses>',
        '<number_feature_variables>1</number_feature_variables>',
        '<variable0>beamE</variable0>',
        '<number_spectator_variables>0</number_spectator_variables>',
        '<number_data_files>1</number_data_files>',
        '<datafile0>train.root</datafile0>',
        '<Trivial_version>1</Trivial_version>',
        '<Trivial_output>0</Trivial_output>',
        '<Trivial_number_of_multiple_outputs>3</Trivial_number_of_multiple_outputs>',
    ]
    # One output element per class, each shifted by its class index.
    outputs = [
        f'<Trivial_multiple_output{index}>{offset_val + index}</Trivial_multiple_output{index}>'
        for index in range(n_outputs)
    ]
    footer = ['<signal_fraction>0.5</signal_fraction>']

    body = "".join(f"{indent}{line}\n" for line in header + outputs + footer)
    return "\n" + body + indent
136
137
if __name__ == "__main__":
    # Local import: only this script section needs it.
    import os

    path = basf2.create_path()
    ma.inputMdst(b2test_utils.require_file('analysis/tests/mdst.root'), path=path)

    ma.fillParticleList('pi+:test', '', path=path)
    ma.fillParticleList('Lambda0:test -> p+ pi-', '', path=path)

    # test all combinations of [single expert, multiexpert] x [binary, multiclass]
    for prefix in ['', 'multiclass_']:
        # Each extraInfo field is first filled from the "mid" weightfile and then a
        # second expert with the "low"/"high" weightfile tries to overwrite it using
        # the same overwrite option.
        for identifier, extra_info_name, overwrite_option in [
            # set the initial values
            ('weightfile_mid.xml', 'low_never', 0),
            ('weightfile_mid.xml', 'low_always', 2),
            ('weightfile_mid.xml', 'low_lower', -1),
            ('weightfile_mid.xml', 'low_higher', 1),
            ('weightfile_mid.xml', 'high_never', 0),
            ('weightfile_mid.xml', 'high_always', 2),
            ('weightfile_mid.xml', 'high_lower', -1),
            ('weightfile_mid.xml', 'high_higher', 1),
            # try to overwrite them
            ('weightfile_low.xml', 'low_never', 0),
            ('weightfile_low.xml', 'low_always', 2),
            ('weightfile_low.xml', 'low_lower', -1),
            ('weightfile_low.xml', 'low_higher', 1),
            ('weightfile_high.xml', 'high_never', 0),
            ('weightfile_high.xml', 'high_always', 2),
            ('weightfile_high.xml', 'high_lower', -1),
            ('weightfile_high.xml', 'high_higher', 1),
        ]:
            extra_info_name = prefix+extra_info_name
            identifier = prefix+identifier
            # FSP use case
            path.add_module(
                'MVAExpert',
                listNames=['pi+:test'],
                extraInfoName=extra_info_name,
                identifier=identifier,
                overwriteExistingExtraInfo=overwrite_option)
            # Decay string use case: the expert is applied to the selected daughter
            path.add_module(
                'MVAExpert',
                listNames=['Lambda0:test -> ^p+ pi-'],
                extraInfoName=extra_info_name,
                identifier=identifier,
                overwriteExistingExtraInfo=overwrite_option)
            # No particle list: the checker reads this value from the event extra info
            path.add_module(
                'MVAExpert',
                listNames=[],
                extraInfoName=extra_info_name,
                identifier=identifier,
                overwriteExistingExtraInfo=overwrite_option)

        for identifiers, extra_info_names, overwrite_options in [
            (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_never', 'multi_high_never'], [0, 0]),
            (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_always', 'multi_high_always'], [2, 2]),
            (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_lower', 'multi_high_lower'], [-1, -1]),
            (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_higher', 'multi_high_higher'], [1, 1]),
            (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_never', 'multi_high_never'], [0, 0]),
            (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_always', 'multi_high_always'], [2, 2]),
            (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_higher', 'multi_high_lower'], [1, -1]),
            (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_lower', 'multi_high_higher'], [-1, 1]),
        ]:
            extra_info_names = [prefix+x for x in extra_info_names]
            identifiers = [prefix+x for x in identifiers]
            path.add_module(
                'MVAMultipleExperts',
                listNames=['pi+:test'],
                extraInfoNames=extra_info_names,
                identifiers=identifiers,
                overwriteExistingExtraInfo=overwrite_options)
            path.add_module('MVAMultipleExperts', listNames=[], extraInfoNames=extra_info_names,
                            identifiers=identifiers, overwriteExistingExtraInfo=overwrite_options)

    path.add_module(MVAExtraInfoChecker())

    with tempfile.TemporaryDirectory() as tempdir:
        # Write the weightfiles into the scratch directory instead of the caller's
        # working directory, so nothing is left behind after the test.
        for name, val in [('weightfile_low.xml', 0.0),
                          ('weightfile_mid.xml', 0.5),
                          ('weightfile_high.xml', 1.0)]:
            with open(os.path.join(tempdir, name), "w") as f:
                f.write(create_weightfile(name, val))

        for name, val in [('multiclass_weightfile_low.xml', 0.0),
                          ('multiclass_weightfile_mid.xml', 0.5),
                          ('multiclass_weightfile_high.xml', 1.0)]:
            with open(os.path.join(tempdir, name), "w") as f:
                f.write(create_multiclass_weightfile(name, val))

        # The modules were configured with bare file names as identifiers, so run
        # from inside the scratch directory to keep them resolvable.  The previous
        # directory is restored before the scratch directory is removed.
        # NOTE(review): assumes require_file() returned an absolute path - confirm.
        previous_cwd = os.getcwd()
        os.chdir(tempdir)
        try:
            basf2.process(path, 10)
        finally:
            os.chdir(previous_cwd)
a (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
comp_plist
Composite particle list object.
def require_file(filename, data_type="", py_case=None)
Definition: __init__.py:54