Belle II Software  release-05-02-19
regression_in_python.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 import basf2_mva
4 import pandas as pd
5 from root_pandas import to_root, read_root
6 import numpy as np
7 from matplotlib import pyplot as plt
8 
9 
def train_mva_method(file_name):
    """Train a FastBDT regression method on the tree stored in *file_name*.

    The teacher is configured to read the tree named ``tree`` with ``A`` and
    ``B`` as input variables and ``C`` as the target variable.

    Returns the identifier (``"weightfile.root"``) that was passed to the
    teacher via ``m_identifier``.
    """
    identifier = "weightfile.root"

    options = basf2_mva.GeneralOptions()
    options.m_datafiles = basf2_mva.vector(file_name)
    options.m_treename = "tree"
    options.m_identifier = identifier
    options.m_variables = basf2_mva.vector('A', 'B')
    options.m_target_variable = 'C'

    regression_options = basf2_mva.RegressionFastBDTOptions()
    # Regression specific settings are set directly on this object ...
    regression_options.setMaximalBinNumber(20)
    # ... while settings of the base classifier would be reachable through
    # base_options = regression_options.getBaseClassifierOptions()

    basf2_mva.teacher(options, regression_options)

    return identifier
30 
31 
def create_random_data():
    """Write a small toy dataset to ``data.root`` and return that file name.

    The dataset has 1000 rows. Columns ``A`` and ``B`` are drawn with
    ``np.random.rand`` and the target column ``C`` is their average. The
    frame is stored under the key ``tree`` without its index.
    """
    n_rows = 1000
    output_name = "data.root"

    # Draw A before B so the sequence of random numbers is unchanged.
    column_a = np.random.rand(n_rows)
    column_b = np.random.rand(n_rows)

    frame = pd.DataFrame({"A": column_a, "B": column_b})
    frame["C"] = (frame["A"] + frame["B"]) / 2

    to_root(frame, output_name, store_index=False, key="tree")
    return output_name
41 
42 
def apply_expert(file_name, weight_file):
    """Run the MVA expert for *weight_file* on the tree in *file_name*.

    The expert reads the tree named ``tree`` and is given ``expert.root`` as
    its output file. That output file name is returned.
    """
    result_file = 'expert.root'

    identifiers = basf2_mva.vector(weight_file)
    datafiles = basf2_mva.vector(file_name)
    basf2_mva.expert(identifiers, datafiles, 'tree', result_file)

    return result_file
47 
48 
def create_plot(expert_file):
    """Scatter-plot the expert output against the target and save it.

    Reads *expert_file*, plots column ``weightfile__ptroot_C`` (labelled
    "Correct") on the x axis against ``weightfile__ptroot`` (labelled
    "Output") on the y axis, and writes the figure to ``result.pdf``.
    """
    frame = read_root(expert_file)

    # NOTE(review): the column names presumably derive from the identifier
    # "weightfile.root" used for training - confirm before changing either.
    axes = plt.gca()
    frame.plot.scatter(x="weightfile__ptroot_C", y="weightfile__ptroot", ax=axes)

    plt.xlabel("Correct")
    plt.ylabel("Output")
    plt.savefig("result.pdf")
55 
56 
if __name__ == "__main__":
    from basf2 import conditions

    # NOTE: testing payloads must never be used in production! Results
    # obtained with them WILL NOT BE PUBLISHED.
    conditions.testing_payloads = ['localdb/database.txt']

    # Step 1: write a random toy dataset to disk
    file_name = create_random_data()

    # Step 2: train the regression MVA method on it
    weight_file = train_mva_method(file_name)

    # Step 3: apply the trained method to the same data
    expert_file = apply_expert(file_name, weight_file)

    # Step 4: produce an example plot from the expert output
    create_plot(expert_file)