Belle II Software  light-2212-foldex
test_MVAExpertModule.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import basf2
13 import modularAnalysis as ma
14 import tempfile
15 import b2test_utils
16 
17 from ROOT import Belle2
18 
19 """
20 Check that the MVAExpert and MVAMultipleExperts modules properly set the expected extraInfo fields
21  and that the overwrite settings are respected.
22 """
23 
24 
class MVAExtraInfoChecker(basf2.Module):
    """Check if the extra Info values are correctly overwritten

    For every combination of [binary, multiclass] x [MVAExpert, MVAMultipleExperts]
    the value stored under each extraInfo name is compared, both on every particle
    of ``pi+:test`` and in ``EventExtraInfo``, with the value the corresponding
    overwrite setting should have left there.
    """

    #: (extraInfo base name, expected value) after all experts have run.
    #: Every name is first filled from the "mid" weightfile (0.5); the ``low_*``
    #: names are then re-filled from the "low" weightfile (0.0) and the ``high_*``
    #: names from the "high" weightfile (1.0) with the respective overwrite setting.
    EXPECTED_VALUES = (
        ('low_never', 0.5),
        ('low_always', 0.0),
        ('low_higher', 0.5),
        ('low_lower', 0.0),

        ('high_never', 0.5),
        ('high_always', 1.0),
        ('high_higher', 1.0),
        ('high_lower', 0.5),
    )

    #: number of classes in the multiclass weightfiles; class ``i`` outputs base value + i
    N_CLASSES = 3

    def initialize(self):
        """Create particle list object"""
        #: particle list whose particles carry the MVA extraInfo under test
        self.plist = Belle2.PyStoreObj("pi+:test")
        #: event-level extra info, filled by the experts run without particle lists
        self.eventExtraInfo = Belle2.PyStoreObj("EventExtraInfo")

    def _check_extra_info(self, extra_info_name, expected):
        """Assert that ``extra_info_name`` equals ``expected`` on each particle and in EventExtraInfo"""
        for particle in self.plist:
            ei_value = particle.getExtraInfo(extra_info_name)
            assert ei_value == expected, \
                f'ExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected "{expected}"'

        ei_value = self.eventExtraInfo.getExtraInfo(extra_info_name)
        assert ei_value == expected, \
            f'eventExtraInfo "{extra_info_name}" value "{ei_value}" not what was expected "{expected}"'

    def event(self):
        """check the extra info names are what we expect!"""
        for multiclass in [False, True]:
            # 'multi_' names were filled by MVAMultipleExperts, bare names by MVAExpert
            for multiexpert_prefix in ['multi_', '']:
                for base_name, value in self.EXPECTED_VALUES:
                    name = multiexpert_prefix + base_name
                    if multiclass:
                        # one extraInfo per class, suffixed with the class index
                        for index in range(self.N_CLASSES):
                            self._check_extra_info(f'multiclass_{name}_{index}', index + value)
                    else:
                        self._check_extra_info(name, value)
71 
72 
73 # prepare the weightfiles
def create_weightfile(name, val):
    """Return the XML content of a binary weightfile for the ``Trivial`` MVA method.

    The test relies on the expert reporting ``val`` as the classifier output
    (the checker compares the extraInfo against exactly these values).

    :param name: value written into the ``<weightfile>`` tag
    :param val: value written into ``<Trivial_output>``
    :return: the weightfile text; it starts directly with the XML declaration
             (no leading whitespace, which is not allowed before ``<?xml ...?>``)
    """
    lines = [
        '<?xml version="1.0" encoding="utf-8"?>',
        '<method>Trivial</method>',
        f'<weightfile>{name}</weightfile>',
        '<treename>tree</treename>',
        '<target_variable>isSignal</target_variable>',
        '<weight_variable>__weight__</weight_variable>',
        '<signal_class>1</signal_class>',
        '<max_events>0</max_events>',
        '<number_feature_variables>1</number_feature_variables>',
        '<variable0>beamE</variable0>',
        '<number_spectator_variables>0</number_spectator_variables>',
        '<number_data_files>1</number_data_files>',
        '<datafile0>train.root</datafile0>',
        '<Trivial_version>1</Trivial_version>',
        f'<Trivial_output>{val}</Trivial_output>',
        '<signal_fraction>0.5</signal_fraction>',
    ]
    return "\n".join(lines) + "\n"
93 
94 
def create_multiclass_weightfile(name, offset_val, n_classes=3):
    """Return the XML content of a multiclass weightfile for the ``Trivial`` MVA method.

    Output ``i`` (for ``i`` in ``range(n_classes)``) is set to ``offset_val + i``;
    the checker expects exactly these per-class values.

    :param name: value written into the ``<weightfile>`` tag
    :param offset_val: output of class 0; class ``i`` outputs ``offset_val + i``
    :param n_classes: number of classes / multiple outputs (default 3)
    :return: the weightfile text; it starts directly with the XML declaration
             (no leading whitespace, which is not allowed before ``<?xml ...?>``)
    :raises ValueError: if ``n_classes`` is smaller than 1
    """
    if n_classes < 1:
        raise ValueError(f"n_classes must be at least 1, got {n_classes}")

    lines = [
        '<?xml version="1.0" encoding="utf-8"?>',
        '<method>Trivial</method>',
        f'<weightfile>{name}</weightfile>',
        '<treename>tree</treename>',
        '<target_variable>isSignal</target_variable>',
        '<weight_variable>__weight__</weight_variable>',
        '<signal_class>1</signal_class>',
        '<max_events>0</max_events>',
        f'<nClasses>{n_classes}</nClasses>',
        '<number_feature_variables>1</number_feature_variables>',
        '<variable0>beamE</variable0>',
        '<number_spectator_variables>0</number_spectator_variables>',
        '<number_data_files>1</number_data_files>',
        '<datafile0>train.root</datafile0>',
        '<Trivial_version>1</Trivial_version>',
        '<Trivial_output>0</Trivial_output>',
        f'<Trivial_number_of_multiple_outputs>{n_classes}</Trivial_number_of_multiple_outputs>',
    ]
    # one output tag per class, offset by the class index
    lines += [f'<Trivial_multiple_output{i}>{offset_val + i}</Trivial_multiple_output{i}>'
              for i in range(n_classes)]
    lines.append('<signal_fraction>0.5</signal_fraction>')
    return "\n".join(lines) + "\n"
119 
120 
if __name__ == "__main__":
    # local import: only the script entry point needs path handling
    import os

    # The weightfiles are written into the temporary directory and the experts
    # are pointed at them by absolute path. The original wrote them to the
    # current working directory (tempdir was unused) and left them behind.
    with tempfile.TemporaryDirectory() as tempdir:

        def weightfile_path(filename):
            """Absolute path of ``filename`` inside the temporary directory"""
            return os.path.join(tempdir, filename)

        # prepare the weightfiles (must exist before basf2.process initializes the experts)
        for name, val in [('weightfile_low.xml', 0.0),
                          ('weightfile_mid.xml', 0.5),
                          ('weightfile_high.xml', 1.0)]:
            with open(weightfile_path(name), "w") as f:
                f.write(create_weightfile(name, val))

        for name, val in [('multiclass_weightfile_low.xml', 0.0),
                          ('multiclass_weightfile_mid.xml', 0.5),
                          ('multiclass_weightfile_high.xml', 1.0)]:
            with open(weightfile_path(name), "w") as f:
                f.write(create_multiclass_weightfile(name, val))

        path = basf2.create_path()
        ma.inputMdst(b2test_utils.require_file('analysis/tests/mdst.root'), path=path)

        ma.fillParticleList('pi+:test', '', path=path)

        # each expert runs twice: on the particle list, and with no lists (EventExtraInfo)
        list_name_variants = (['pi+:test'], [])

        # test all combinations of [single expert, multiexpert] x [binary, multiclass]
        for prefix in ['', 'multiclass_']:
            for identifier, extra_info_name, overwrite_option in [
                # set the initial values
                ('weightfile_mid.xml', 'low_never', 0),
                ('weightfile_mid.xml', 'low_always', 2),
                ('weightfile_mid.xml', 'low_lower', -1),
                ('weightfile_mid.xml', 'low_higher', 1),
                ('weightfile_mid.xml', 'high_never', 0),
                ('weightfile_mid.xml', 'high_always', 2),
                ('weightfile_mid.xml', 'high_lower', -1),
                ('weightfile_mid.xml', 'high_higher', 1),
                # try to overwrite them
                ('weightfile_low.xml', 'low_never', 0),
                ('weightfile_low.xml', 'low_always', 2),
                ('weightfile_low.xml', 'low_lower', -1),
                ('weightfile_low.xml', 'low_higher', 1),
                ('weightfile_high.xml', 'high_never', 0),
                ('weightfile_high.xml', 'high_always', 2),
                ('weightfile_high.xml', 'high_lower', -1),
                ('weightfile_high.xml', 'high_higher', 1),
            ]:
                extra_info_name = prefix + extra_info_name
                identifier = weightfile_path(prefix + identifier)
                for list_names in list_name_variants:
                    path.add_module(
                        'MVAExpert',
                        listNames=list_names,
                        extraInfoName=extra_info_name,
                        identifier=identifier,
                        overwriteExistingExtraInfo=overwrite_option)

            for identifiers, extra_info_names, overwrite_options in [
                (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_never', 'multi_high_never'], [0, 0]),
                (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_always', 'multi_high_always'], [2, 2]),
                (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_lower', 'multi_high_lower'], [-1, -1]),
                (['weightfile_mid.xml', 'weightfile_mid.xml'], ['multi_low_higher', 'multi_high_higher'], [1, 1]),
                (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_never', 'multi_high_never'], [0, 0]),
                (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_always', 'multi_high_always'], [2, 2]),
                (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_higher', 'multi_high_lower'], [1, -1]),
                (['weightfile_low.xml', 'weightfile_high.xml'], ['multi_low_lower', 'multi_high_higher'], [-1, 1]),
            ]:
                extra_info_names = [prefix + x for x in extra_info_names]
                identifiers = [weightfile_path(prefix + x) for x in identifiers]
                for list_names in list_name_variants:
                    path.add_module(
                        'MVAMultipleExperts',
                        listNames=list_names,
                        extraInfoNames=extra_info_names,
                        identifiers=identifiers,
                        overwriteExistingExtraInfo=overwrite_options)

        path.add_module(MVAExtraInfoChecker())

        basf2.process(path, 10)
a (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
def require_file(filename, data_type="", py_case=None)
Definition: __init__.py:54