Belle II Software  release-06-00-14
regression_in_python.py
1 #!/usr/bin/env python3
2 
3 
10 
11 import basf2_mva
12 import pandas as pd
13 from root_pandas import to_root, read_root
14 import numpy as np
15 from matplotlib import pyplot as plt
16 
17 
def train_mva_method(file_name):
    """Train a FastBDT regression on ``file_name`` and return its identifier.

    The tree ``"tree"`` in ``file_name`` is used as training data, with the
    columns ``A`` and ``B`` as input variables and ``C`` as the regression
    target. The method is registered under the identifier
    ``"weightfile.root"``. The same string is returned so it can be handed to
    the expert afterwards.
    """
    identifier = "weightfile.root"

    options = basf2_mva.GeneralOptions()
    options.m_datafiles = basf2_mva.vector(file_name)
    options.m_treename = "tree"
    options.m_identifier = identifier
    options.m_variables = basf2_mva.vector('A', 'B')
    options.m_target_variable = 'C'

    # Regression-specific settings are set directly on this options object.
    # Settings of the underlying classifier can be reached through
    # fastbdt.getBaseClassifierOptions().
    fastbdt = basf2_mva.RegressionFastBDTOptions()
    fastbdt.setMaximalBinNumber(20)

    basf2_mva.teacher(options, fastbdt)
    return identifier
38 
39 
def create_random_data(*, n_samples=1000, seed=None):
    """Write a toy regression dataset to ``data.root`` and return the file name.

    Columns ``A`` and ``B`` are drawn uniformly from [0, 1). The target ``C``
    is their average, so a perfect regression reproduces ``(A + B) / 2``.

    Args:
        n_samples: number of rows to generate (default 1000).
        seed: optional integer seed. When given, a private ``RandomState``
            is used so the dataset is reproducible. When ``None`` (default),
            the global ``np.random`` state is used, exactly as before.

    Returns:
        The name of the written ROOT file, ``"data.root"``. The data is
        stored in a tree called ``"tree"``.
    """
    file_name = "data.root"

    # np.random and RandomState both provide .rand(); keeping the global
    # module for seed=None preserves any np.random.seed() set by the caller.
    rng = np.random if seed is None else np.random.RandomState(seed)
    df = pd.DataFrame({"A": rng.rand(n_samples), "B": rng.rand(n_samples)})
    df["C"] = (df.A + df.B) / 2

    # store_index=False keeps the pandas index out of the output tree.
    to_root(df, file_name, store_index=False, key="tree")
    return file_name
49 
50 
def apply_expert(file_name, weight_file):
    """Evaluate a trained method on ``file_name`` and return the output file name.

    ``basf2_mva.expert`` is run with ``weight_file`` on the tree ``"tree"``
    of ``file_name``. The expert output goes to ``"expert.root"``, whose name
    is returned.
    """
    weightfiles = basf2_mva.vector(weight_file)
    datafiles = basf2_mva.vector(file_name)
    result = 'expert.root'
    basf2_mva.expert(weightfiles, datafiles, 'tree', result)
    return result
55 
56 
def create_plot(expert_file, plot_file="result.pdf"):
    """Scatter the true target against the regression output and save the plot.

    Args:
        expert_file: ROOT file written by ``basf2_mva.expert``. It must contain
            the columns ``weightfile__ptroot_C`` (plotted as "Correct") and
            ``weightfile__ptroot`` (plotted as "Output").
            NOTE(review): these names appear to be derived from the training
            identifier ``"weightfile.root"``; confirm if the identifier changes.
        plot_file: path of the saved figure (default ``"result.pdf"``).

    Returns:
        ``plot_file``, for consistency with the other helpers that return the
        file they produce.
    """
    df = read_root(expert_file)

    # Use a dedicated figure. plt.gca() would reuse whatever figure happens to
    # be current, so repeated calls piled points into one plot, and the figure
    # was never released.
    fig, ax = plt.subplots()
    try:
        df.plot.scatter("weightfile__ptroot_C", "weightfile__ptroot", ax=ax)
        ax.set_xlabel("Correct")
        ax.set_ylabel("Output")
        fig.savefig(plot_file)
    finally:
        plt.close(fig)
    return plot_file
63 
64 
if __name__ == "__main__":
    from basf2 import conditions
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Let's create some random data
    file_name = create_random_data()

    # Train a regression MVA method
    weight_file = train_mva_method(file_name)

    # Apply the trained method to the data
    expert_file = apply_expert(file_name, weight_file)

    # And generate an example plot
    create_plot(expert_file)