15 from matplotlib
import pyplot
as plt
def train_mva_method(file_name):
    """Train a regression FastBDT on the data stored in *file_name*.

    The training reads the tree named ``"tree"``, uses the columns ``A`` and
    ``B`` as input variables and ``C`` as the regression target.

    Args:
        file_name: Path of the ROOT file holding the training data.

    Returns:
        The identifier (``"weightfile.root"``) under which the trained
        method is stored, so callers can hand it on to the expert.
    """
    weight_file = "weightfile.root"

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(file_name)
    general_options.m_treename = "tree"
    general_options.m_identifier = weight_file
    general_options.m_variables = basf2_mva.vector('A', 'B')
    general_options.m_target_variable = 'C'

    regression_fastbdt_options = basf2_mva.RegressionFastBDTOptions()
    # NOTE(review): presumably caps the number of bins the regression target
    # is discretised into -- confirm against the basf2 MVA documentation.
    regression_fastbdt_options.setMaximalBinNumber(20)

    basf2_mva.teacher(general_options, regression_fastbdt_options)

    # The caller feeds this identifier into apply_expert(), so it must be
    # returned rather than left as a local.
    return weight_file
def create_random_data():
    """Create a small random regression data set and write it to disk.

    Two columns ``A`` and ``B`` with 1000 uniform random values each are
    generated; the target column ``C`` is their arithmetic mean. The frame is
    written as the tree ``"tree"`` into ``"data.root"``.

    Returns:
        The name of the ROOT file that was written (``"data.root"``).
    """
    file_name = "data.root"

    df = pd.DataFrame({
        "A": np.random.rand(1000),
        "B": np.random.rand(1000),
    })
    df["C"] = (df.A + df.B) / 2

    with uproot.recreate(file_name) as outfile:
        # The tree name has to match the one the training and the expert
        # read back ("tree").
        outfile["tree"] = df

    return file_name
def apply_expert(file_name, weight_file):
    """Apply the trained method to the data and store the expert output.

    Args:
        file_name: Path of the ROOT file with the input data (tree ``'tree'``).
        weight_file: Identifier of the trained method to apply.

    Returns:
        The name of the ROOT file the expert output was written to
        (``'expert.root'``).
    """
    output_file = 'expert.root'
    basf2_mva.expert(
        basf2_mva.vector(weight_file),
        basf2_mva.vector(file_name),
        'tree',
        output_file,
    )
    # Returned so the caller can pass the result file on to create_plot().
    return output_file
def create_plot(expert_file):
    """Draw a scatter plot from the expert output and save it as a PDF.

    Reads the ``"variables"`` tree of *expert_file* into a pandas DataFrame,
    plots column ``"weightfile__ptroot_C"`` on the x axis against
    ``"weightfile__ptroot"`` on the y axis, and writes the current figure to
    ``"result.pdf"``.

    Args:
        expert_file: Path of the ROOT file produced by the expert.
    """
    # Use a context manager so the ROOT file handle is closed once the
    # columns have been loaded into memory.
    with uproot.open(expert_file) as infile:
        df = infile["variables"].arrays(library="pd")

    df.plot.scatter(
        x="weightfile__ptroot_C",
        y="weightfile__ptroot",
        ax=plt.gca(),
    )

    plt.savefig("result.pdf")
if __name__ == "__main__":
    from basf2 import conditions

    # Register a local payload list with the conditions configuration.
    # NOTE(review): testing payloads are presumably meant for local tests
    # only -- confirm before reusing this setup elsewhere.
    conditions.testing_payloads = [
        'localdb/database.txt',
    ]

    # Create a random data set.
    file_name = create_random_data()

    # Train the regression method on it.
    weight_file = train_mva_method(file_name)

    # Apply the trained method to the same data.
    expert_file = apply_expert(file_name, weight_file)

    # Plot the expert output and save the figure.
    create_plot(expert_file)