Belle II Software  release-05-02-19
root_returnvalues.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 """
5 Check that ROOT returns the correct values (ints, floats) when calling a C++
6 function from python
7 """
8 
9 import sys
10 import ROOT
11 import numpy as np
12 
# C++ source lines of the struct used for the test. For every test case below
# one member returning the value and one member returning a reference to it is
# appended; the closing "};" is added right before compilation.
testclass = [
    "#include <cstdint>",
    "#include <limits>",
    "struct ReturnValueTests {"
]

# [member function name, expected value, numpy dtype] for every generated member
results = []
# what integer sizes to check
integral_types = [8, 16, 32, 64]
# which floating types to check
floating_types = ["float", "double"]
# nan and inf have no literal in C++ so map them to the corresponding
# std::numeric_limits expression ({0} is replaced by the C++ type name)
floating_expressions = {
    -np.inf: "-std::numeric_limits<{0}>::infinity()",
    np.inf: "std::numeric_limits<{0}>::infinity()",
    np.nan: "std::numeric_limits<{0}>::quiet_NaN()",
}

for width in integral_types:
    # literals wider than 32 bit get an explicit "ull" suffix, all others none.
    # NOTE(review): "ull" is presumably used for the signed case as well
    # because the magnitude of the 64 bit minimum does not fit a signed
    # literal -- confirm before changing it.
    suffix = "ull" if width > 32 else ""
    cases = [
        # signed types: we want min, -1, 0, 1 and max
        ("int", [-2**(width - 1), -1, 0, 1, 2**(width - 1) - 1]),
        # unsigned types: just 0, 1, and max
        ("uint", [0, 1, 2**width - 1]),
    ]
    for kind, values in cases:
        dtype = np.dtype("%s%d" % (kind, width))
        for i, value in enumerate(values):
            # create a unique member name
            func_name = "{0}{1}test{2}".format(kind, width, i)
            # append a by-value and a by-reference getter to the test class
            testclass.append("{0}{1}_t {2}() const {{ return {3}{4}; }}".format(
                kind, width, func_name, value, suffix))
            testclass.append("{0}{1}_t& ref_{2}() {{ static {0}{1}_t val = {3}{4}; return val; }}".format(
                kind, width, func_name, value, suffix))
            # and remember name and result for checking
            results.append([func_name, value, dtype])
            results.append(['ref_' + func_name, value, dtype])

# now add floating types
for t in floating_types:
    # and exploit that numpy offers a numeric_limits equivalent: the first
    # letter of the C++ type name ("f"/"d") is handed to numpy as type code
    info = np.finfo(t[0])

    # check some values
    values = [-np.inf, info.min, -1, 0, info.tiny, info.eps, 1, info.max, np.inf, np.nan]
    for i, value in enumerate(values):
        func_name = "{0}test{1}".format(t, i)
        # see if we can just have the literal or if we need a mapping
        if value in floating_expressions:
            expression = floating_expressions[value].format(t)
        else:
            # convert to a plain python float first: the repr() of a numpy
            # scalar may look like "np.float32(1.0)" which is not valid C++
            expression = repr(float(value))
        testclass.append("{0} {1}() const {{ return {2}; }}".format(t, func_name, expression))
        testclass.append("{0}& ref_{1}() {{ static {0} val = {2}; return val; }}".format(t, func_name, expression))
        results.append([func_name, value, info.dtype])
        results.append(['ref_' + func_name, value, info.dtype])
76 
77 
# close the struct definition and compile the test class
testclass.append("};")
ROOT.gROOT.ProcessLine("\n".join(testclass))
# get an instance
tests = ROOT.ReturnValueTests()
# and do all the checks: call every generated member and compare what arrives
# in python with the value that was written into the C++ source
failures = 0
for func, value, dtype in results:
    ret = getattr(tests, func)()
    # char is odd and returns strings instead of int. The simple way of using
    # ord(string) will lose the sign so we have to convert it to the correct
    # type using numpy but for that we have to convert the string to 8bit.
    if isinstance(ret, str):
        # frombuffer replaces the deprecated binary mode of np.fromstring
        ret = np.frombuffer(ret.encode("latin1"), dtype)[0]

    # print the test
    print("check %s(): %r == %r: " % (func, value, ret), end="")
    # nan needs special care because it cannot be compared to itself
    if np.isnan(value):
        passed = np.isnan(ret)
    else:
        passed = (ret == value)

    print("\033[32mOK\033[0m" if passed else "\033[31mFAIL\033[0m")
    if not passed:
        failures += 1

# the exit status is the number of failed checks; clamp it so that a large
# count cannot wrap around to 0 (exit codes are taken modulo 256)
sys.exit(min(failures, 255))