Belle II Software  release-08-01-10
test_MVAExpertModule.py
1 #!/usr/bin/env python3
2 
3 
10 
import os
import tempfile

import basf2
import b2test_utils
import modularAnalysis as ma

from ROOT import Belle2
17 
"""
Check that the MVAExpertModule and MVAMultipleExpertsModule properly set the extraInfo fields they should,
and that the overwrite settings are respected.
"""
22 
23 
class MVAExtraInfoChecker(basf2.Module):
    """Check that the MVA experts wrote the expected extraInfo values, i.e. that the overwrite settings are respected"""

    #: (name, expected value) of every extraInfo written by the experts of this steering file.
    #: Each name is first filled from the weightfile configured with output 0.5 and then offered
    #: 0.0 (``low_*``) or 1.0 (``high_*``); the suffix names the overwrite option that was used.
    expected_values = [
        ('low_never', 0.5),
        ('low_always', 0.0),
        ('low_higher', 0.5),
        ('low_lower', 0.0),

        ('high_never', 0.5),
        ('high_always', 1.0),
        ('high_higher', 1.0),
        ('high_lower', 0.5),
    ]

    #: Number of classes of the multiclass weightfiles; class ``i`` is expected to hold value + i
    n_classes = 3

    def initialize(self):
        """Attach to the particle lists and the event extra info that the experts filled"""

        #: Particle list object
        self.plist = Belle2.PyStoreObj("pi+:test")

        #: Composite particle list object
        self.comp_plist = Belle2.PyStoreObj("Lambda0:test")

        #: Event extra info object
        self.eventExtraInfo = Belle2.PyStoreObj("EventExtraInfo")

    def _check(self, extra_info_name, expected, expected_str):
        """Assert that ``extra_info_name`` equals ``expected`` on the FSPs, the Lambda0 daughters and the event.

        ``expected_str`` is how the expected value is shown in the assertion messages.
        """
        # FSP use case
        for p in self.plist:
            ei_value = p.getExtraInfo(extra_info_name)
            assert ei_value == expected, \
                f'ExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {expected_str}'

        # Decay string use case; the MVAMultipleExperts modules of this steering file are not run on the
        # Lambda0 list, so the multi_ extraInfos are not checked on its daughters
        if 'multi_' not in extra_info_name:
            for p in self.comp_plist:
                ei_value = p.getDaughter(0).getExtraInfo(extra_info_name)
                assert ei_value == expected, \
                    f'ExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {expected_str}'

        ei_value = self.eventExtraInfo.getExtraInfo(extra_info_name)
        assert ei_value == expected, \
            f'eventExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected {expected_str}'

    def event(self):
        """check the extra info names are what we expect!"""
        for multiclass in [False, True]:
            for multiexpert_prefix in ['multi_', '']:
                for base_name, value in self.expected_values:
                    name = multiexpert_prefix + base_name
                    if multiclass:
                        # one extraInfo per class, named <name>_<index>, holding value + index
                        for index in range(self.n_classes):
                            compare_val = index + value
                            self._check(f'multiclass_{name}_{index}', compare_val, f'{compare_val}')
                    else:
                        self._check(name, value, f'"{value}"')
88 
89 
90 # prepare the weightfiles
def create_weightfile(name, val):
    """Return the XML content of a binary MVA weightfile using the ``Trivial`` method.

    :param name: value written into the ``<weightfile>`` field (the file name used by this test)
    :param val: value written into ``<Trivial_output>``, i.e. the output the expert is configured to return
    :return: the weightfile content; the XML declaration is the very first line, as XML requires
    """
    lines = [
        # must be the first bytes of the document: strict XML parsers reject a declaration
        # preceded by whitespace
        '<?xml version="1.0" encoding="utf-8"?>',
        '<method>Trivial</method>',
        f'<weightfile>{name}</weightfile>',
        '<treename>tree</treename>',
        '<target_variable>isSignal</target_variable>',
        '<weight_variable>__weight__</weight_variable>',
        '<signal_class>1</signal_class>',
        '<max_events>0</max_events>',
        '<number_feature_variables>1</number_feature_variables>',
        '<variable0>beamE</variable0>',
        '<number_spectator_variables>0</number_spectator_variables>',
        '<number_data_files>1</number_data_files>',
        '<datafile0>train.root</datafile0>',
        '<Trivial_version>1</Trivial_version>',
        f'<Trivial_output>{val}</Trivial_output>',
        '<signal_fraction>0.5</signal_fraction>',
    ]
    return "\n".join(lines) + "\n"
110 
111 
def create_multiclass_weightfile(name, offset_val, n_classes=3):
    """Return the XML content of a multiclass MVA weightfile using the ``Trivial`` method.

    :param name: value written into the ``<weightfile>`` field (the file name used by this test)
    :param offset_val: output configured for class 0; class ``i`` is configured to output ``offset_val + i``
    :param n_classes: number of classes (and of configured outputs), 3 by default
    :return: the weightfile content; the XML declaration is the very first line, as XML requires
    :raises ValueError: if ``n_classes`` is smaller than 1
    """
    if n_classes < 1:
        raise ValueError(f"n_classes must be at least 1, got {n_classes}")

    lines = [
        # must be the first bytes of the document: strict XML parsers reject a declaration
        # preceded by whitespace
        '<?xml version="1.0" encoding="utf-8"?>',
        '<method>Trivial</method>',
        f'<weightfile>{name}</weightfile>',
        '<treename>tree</treename>',
        '<target_variable>isSignal</target_variable>',
        '<weight_variable>__weight__</weight_variable>',
        '<signal_class>1</signal_class>',
        '<max_events>0</max_events>',
        f'<nClasses>{n_classes}</nClasses>',
        '<number_feature_variables>1</number_feature_variables>',
        '<variable0>beamE</variable0>',
        '<number_spectator_variables>0</number_spectator_variables>',
        '<number_data_files>1</number_data_files>',
        '<datafile0>train.root</datafile0>',
        '<Trivial_version>1</Trivial_version>',
        '<Trivial_output>0</Trivial_output>',
        f'<Trivial_number_of_multiple_outputs>{n_classes}</Trivial_number_of_multiple_outputs>',
    ]
    # one output per class, shifted by the class index so every class is distinguishable
    lines += [f'<Trivial_multiple_output{i}>{offset_val + i}</Trivial_multiple_output{i}>'
              for i in range(n_classes)]
    lines.append('<signal_fraction>0.5</signal_fraction>')
    return "\n".join(lines) + "\n"
136 
137 
if __name__ == "__main__":

    # The weightfiles are written into a throw-away directory and the experts are given the full
    # paths to them, so the test leaves nothing behind in the current working directory.
    with tempfile.TemporaryDirectory() as tempdir:

        def weightfile_path(basename):
            """Return the full path of the weightfile ``basename`` inside the temporary directory"""
            return os.path.join(tempdir, basename)

        # prepare the weightfiles: 'low', 'mid' and 'high' are configured to output 0.0, 0.5 and 1.0
        # (the multiclass ones output that value plus the class index)
        for name, val in [('weightfile_low.xml', 0.0),
                          ('weightfile_mid.xml', 0.5),
                          ('weightfile_high.xml', 1.0)]:
            with open(weightfile_path(name), "w", encoding="utf-8") as f:
                f.write(create_weightfile(name, val))

        for name, val in [('multiclass_weightfile_low.xml', 0.0),
                          ('multiclass_weightfile_mid.xml', 0.5),
                          ('multiclass_weightfile_high.xml', 1.0)]:
            with open(weightfile_path(name), "w", encoding="utf-8") as f:
                f.write(create_multiclass_weightfile(name, val))

        path = basf2.create_path()
        ma.inputMdst(b2test_utils.require_file('analysis/tests/mdst.root'), path=path)

        ma.fillParticleList('pi+:test', '', path=path)
        ma.fillParticleList('Lambda0:test -> p+ pi-', '', path=path)

        # test all combinations of [single expert, multiexpert] x [binary, multiclass]
        for prefix in ['', 'multiclass_']:
            # overwriteExistingExtraInfo as used here: 0 = never, 2 = always, -1 = if lower, 1 = if higher
            for identifier, extra_info_name, overwrite_option in [
                # set the initial values
                ('weightfile_mid.xml', 'low_never', 0),
                ('weightfile_mid.xml', 'low_always', 2),
                ('weightfile_mid.xml', 'low_lower', -1),
                ('weightfile_mid.xml', 'low_higher', 1),
                ('weightfile_mid.xml', 'high_never', 0),
                ('weightfile_mid.xml', 'high_always', 2),
                ('weightfile_mid.xml', 'high_lower', -1),
                ('weightfile_mid.xml', 'high_higher', 1),
                # try to overwrite them
                ('weightfile_low.xml', 'low_never', 0),
                ('weightfile_low.xml', 'low_always', 2),
                ('weightfile_low.xml', 'low_lower', -1),
                ('weightfile_low.xml', 'low_higher', 1),
                ('weightfile_high.xml', 'high_never', 0),
                ('weightfile_high.xml', 'high_always', 2),
                ('weightfile_high.xml', 'high_lower', -1),
                ('weightfile_high.xml', 'high_higher', 1),
            ]:
                extra_info_name = prefix + extra_info_name
                identifier = weightfile_path(prefix + identifier)
                # FSP list, a daughter selected via decay string, and the event (empty list)
                for list_names in (['pi+:test'], ['Lambda0:test -> ^p+ pi-'], []):
                    path.add_module(
                        'MVAExpert',
                        listNames=list_names,
                        extraInfoName=extra_info_name,
                        identifier=identifier,
                        overwriteExistingExtraInfo=overwrite_option)

            for identifiers, extra_info_names, overwrite_options in [
                (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_never', 'multi_high_never'], [0, 0]),
                (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_always', 'multi_high_always'], [2, 2]),
                (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_lower', 'multi_high_lower'], [-1, -1]),
                (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_higher', 'multi_high_higher'], [1, 1]),
                (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_never', 'multi_high_never'], [0, 0]),
                (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_always', 'multi_high_always'], [2, 2]),
                (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_higher', 'multi_high_lower'], [1, -1]),
                (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_lower', 'multi_high_higher'], [-1, 1]),
            ]:
                extra_info_names = [prefix + x for x in extra_info_names]
                identifiers = [weightfile_path(prefix + x) for x in identifiers]
                # the multiple-experts module is only run on the FSP list and on the event
                for list_names in (['pi+:test'], []):
                    path.add_module(
                        'MVAMultipleExperts',
                        listNames=list_names,
                        extraInfoNames=extra_info_names,
                        identifiers=identifiers,
                        overwriteExistingExtraInfo=overwrite_options)

        path.add_module(MVAExtraInfoChecker())

        # process while the temporary directory (and thus the weightfiles) still exists
        basf2.process(path, 10)
a (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
comp_plist
Composite particle list object.
def require_file(filename, data_type="", py_case=None)
Definition: __init__.py:54