Belle II Software light-2406-ragdoll
regression_in_python.py
1#!/usr/bin/env python3
2
3
10
11import basf2_mva
12import pandas as pd
13import uproot
14import numpy as np
15from matplotlib import pyplot as plt
16
17
def train_mva_method(file_name):
    """Train a FastBDT regression on ``file_name`` and return its identifier.

    The training reads the tree ``"tree"`` from ``file_name`` and uses the
    variables ``A`` and ``B`` to predict the target ``C``.  The trained method
    is stored under the identifier ``"weightfile.root"``, which is returned.
    """
    identifier = "weightfile.root"

    # Settings common to every basf2_mva method: input data, features, target.
    options = basf2_mva.GeneralOptions()
    options.m_datafiles = basf2_mva.vector(file_name)
    options.m_treename = "tree"
    options.m_identifier = identifier
    options.m_variables = basf2_mva.vector('A', 'B')
    options.m_target_variable = 'C'

    # Regression-specific FastBDT settings; here only the maximal bin number.
    # Options of the underlying base classifier are reachable through
    # fastbdt_regression.getBaseClassifierOptions() if more tuning is needed.
    fastbdt_regression = basf2_mva.RegressionFastBDTOptions()
    fastbdt_regression.setMaximalBinNumber(20)

    basf2_mva.teacher(options, fastbdt_regression)
    return identifier
39
40
def create_random_data(file_name="data.root", n_events=1000, seed=None):
    """Write a toy regression dataset to a ROOT file and return its path.

    The written tree ``"tree"`` has three branches: ``A`` and ``B`` are drawn
    uniformly from [0, 1) and the regression target ``C`` is their mean.

    Parameters
    ----------
    file_name : str
        Path of the ROOT file to create via ``uproot.recreate``.
        Defaults to ``"data.root"``.
    n_events : int
        Number of rows to generate. Defaults to 1000.
    seed : int or None
        Seed for the random generator. Pass a value to make the dataset
        reproducible; ``None`` (the default) gives a fresh, unseeded draw.

    Returns
    -------
    str
        ``file_name``, so the result can be passed straight to training.
    """
    # A local Generator instead of the global np.random state: the dataset
    # can be reproduced via ``seed`` without touching global RNG state.
    rng = np.random.default_rng(seed)
    df = pd.DataFrame({"A": rng.random(n_events), "B": rng.random(n_events)})
    df["C"] = (df.A + df.B) / 2

    with uproot.recreate(file_name) as outfile:
        outfile['tree'] = df
    return file_name
51
52
def apply_expert(file_name, weight_file):
    """Evaluate the trained method on ``file_name`` and return the output path.

    Calls ``basf2_mva.expert`` with the method identified by ``weight_file``,
    the data file ``file_name``, the tree name ``"tree"`` and the output file
    ``"expert.root"``, whose path is returned.
    """
    identifiers = basf2_mva.vector(weight_file)
    datafiles = basf2_mva.vector(file_name)
    destination = 'expert.root'
    basf2_mva.expert(identifiers, datafiles, 'tree', destination)
    return destination
57
58
def create_plot(expert_file):
    """Scatter the regression output against the true target into ``result.pdf``.

    Reads the ``"variables"`` tree from ``expert_file`` as a pandas DataFrame.
    It plots the column ``weightfile__ptroot_C`` (labelled "Correct") against
    ``weightfile__ptroot`` (labelled "Output").
    NOTE(review): these column names look like the escaped form of the
    identifier ``"weightfile.root"`` used for training; confirm before
    changing the identifier.
    """
    # Close the ROOT file once the branches are materialised in memory.
    with uproot.open(expert_file) as expert:
        df = expert["variables"].arrays(library="pd")

    # Draw on a dedicated figure rather than plt.gca(), so that no earlier
    # pyplot state leaks into this plot, and release the figure afterwards.
    fig, ax = plt.subplots()
    try:
        df.plot.scatter("weightfile__ptroot_C", "weightfile__ptroot", ax=ax)
        ax.set_xlabel("Correct")
        ax.set_ylabel("Output")
        fig.savefig("result.pdf")
    finally:
        plt.close(fig)
65
66
if __name__ == "__main__":
    from basf2 import conditions

    # NOTE: testing payloads must never be used in production! Any results
    # obtained this way WILL NOT BE PUBLISHED.
    # Register the local payload database as testing payloads.
    conditions.testing_payloads = ['localdb/database.txt']

    # End-to-end toy workflow: generate data, train a regression,
    # evaluate it on the same data, then plot output vs. truth.
    data_file = create_random_data()
    weights = train_mva_method(data_file)
    expert_output = apply_expert(data_file, weights)
    create_plot(expert_output)