Belle II Software light-2406-ragdoll
hep_ml_ugboost.py
1#!/usr/bin/env python3
2
3
10
11import basf2_mva
12import hep_ml
13import hep_ml.losses
14import hep_ml.gradientboosting
16
17
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Create a hep_ml gradient-boosting classifier and store it in a State object.

    Column layout: indices ``0 .. number_of_features - 1`` are used as training
    features; the ``number_of_spectators`` indices directly after them are the
    spectators in which the selection should be flat. Spectators are only handed
    to the flatness loss, never used as training features.

    :param number_of_features: number of training-feature columns.
    :param number_of_spectators: number of spectator columns that follow the features.
    :param number_of_events: not used by this method.
    :param training_fraction: not used by this method.
    :param parameters: configuration mapping or None (presumably the decoded JSON
        ``m_config`` of the PythonOptions -- confirm against basf2_mva). Recognised keys:

        * ``uniform_rate`` -- if present, a ``BinFlatnessLossFunction`` with this value as
          ``fl_coefficient`` (``uniform_label=[0, 1]``) is used instead of ``AdaLossFunction``;
        * ``n_estimators`` (default 100), ``subsample`` (default 0.5) and
          ``max_depth`` (default 5) -- forwarded to ``UGradientBoostingClassifier``.
    :return: ``State`` wrapping the (not yet fitted) classifier.
    """
    # NOTE(review): ``State`` is not imported in the visible part of this file; upstream it
    # presumably comes from ``basf2_mva_python_interface.hep_ml`` -- confirm the import exists.
    options = {} if parameters is None else parameters

    train_features = list(range(number_of_features))
    # Spectator columns are laid out directly after the training features.
    uniform_features = list(range(number_of_features, number_of_features + number_of_spectators))

    if 'uniform_rate' in options:
        # Flatness-regularised loss: penalises non-uniform efficiency in the spectators.
        loss = hep_ml.losses.BinFlatnessLossFunction(uniform_features=uniform_features, uniform_label=[0, 1],
                                                     fl_coefficient=options['uniform_rate'])
    else:
        # Plain boosting baseline without any uniformity constraint.
        loss = hep_ml.losses.AdaLossFunction()

    clf = hep_ml.gradientboosting.UGradientBoostingClassifier(loss=loss,
                                                              n_estimators=options.get('n_estimators', 100),
                                                              subsample=options.get('subsample', 0.5),
                                                              max_depth=options.get('max_depth', 5),
                                                              train_features=train_features)
    return State(clf)
32
33
if __name__ == "__main__":
    import json

    from basf2 import conditions, find_file

    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Full set of training variables.
    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))',
                 'M']

    # Reduced subset of `variables`: no per-daughter momenta/angles, no daughter(0/1) vertex
    # variables, no daughterAngle variables and no candidate mass 'M'.
    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    training_data = basf2_mva.vector(train_file)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    # Spectators are the variables for which the selection should be uniform
    general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "hep_ml_baseline"

    def _hep_ml_options(config=None):
        """Return PythonOptions for the hep_ml framework using this steering file; `config` -> JSON m_config."""
        options = basf2_mva.PythonOptions()
        options.m_framework = 'hep_ml'
        options.m_steering_file = 'mva/examples/orthogonal_discriminators/hep_ml_ugboost.py'
        if config is not None:
            options.m_config = json.dumps(config)
        return options

    # 1) Baseline: no 'uniform_rate' in the config, so get_model() uses the plain AdaLossFunction.
    basf2_mva.teacher(general_options, _hep_ml_options())

    # 2) Flatness-constrained boosting: 'uniform_rate' makes get_model() use a BinFlatnessLossFunction
    #    (fl_coefficient = uniform_rate) inside the UGradientBoostingClassifier, flattening the
    #    selection in the spectators.
    general_options.m_identifier = "hep_ml"
    basf2_mva.teacher(general_options, _hep_ml_options({'uniform_rate': 10.0}))

    # 3) Baseline again, but trained on the reduced variable set only.
    general_options.m_identifier = "hep_ml_feature_drop"
    general_options.m_variables = basf2_mva.vector(*variables2)
    basf2_mva.teacher(general_options, _hep_ml_options())