Belle II Software development
lightgbm.py
1
2import numpy as np
3import lightgbm as lgb
4import os
5import tempfile
6
7
class State(object):
    """
    LGBM state

    Plain container that carries the booster, its hyper parameters and the
    bookkeeping values shared by the interface functions of this module
    (get_model, partial_fit, end_fit, load, apply).
    """

    def __init__(self, bst=None, params=None, X_valid=None, y_valid=None, path='LGBM.txt', trainFraction=0.8, num_round=100):
        """
        Constructor of the state object

        Parameters:
            bst: trained LightGBM booster, or None while no model exists yet
            params: dictionary of LightGBM model parameters
            X_valid: validation features, stored as given
            y_valid: validation labels, stored as given
            path: file name used when the model is written out in end_fit
            trainFraction: train fraction for dataset splitting
            num_round: number of boosting rounds, stored as given
        """
        # validation features (only stored; not read by the functions in this file)
        self.X_valid = X_valid
        # validation labels (only stored; not read by the functions in this file)
        self.y_valid = y_valid
        # file name under which the model is saved inside the temporary directory
        self.path = path
        # LightGBM model parameters
        self.params = params
        # LightGBM booster
        self.bst = bst
        # train fraction for dataset splitting
        self.trainFraction = trainFraction
        # number of boosting rounds; the argument used to be accepted but discarded
        self.num_round = num_round
27
28
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Create and return a state object containing the model and other necessary functions

    Parameters:
        number_of_features: not used by this function
        number_of_spectators: not used by this function
        number_of_events: not used by this function
        training_fraction: not used by this function; the split fraction is
            taken from parameters['trainFraction'] instead
        parameters: user configuration. If it is a dict, entries whose key
            matches one of the default hyper parameters below override that
            default, and the optional keys 'path' and 'trainFraction' are
            forwarded to State. Any other value (e.g. None) selects the
            defaults throughout.

    Returns:
        A State holding the resolved hyper parameters; no booster is created here.
    """
    # hyper parameters for lgbm
    param = {'num_leaves': 31,
             'objective': 'regression',
             'learning_rate': 0.1,
             'device_type': "cpu",
             'deterministic': True,
             'metric': 'auc',
             'num_round': 100,
             'max_bin': 255,
             'boosting': "gbdt",
             'num_threads': 1
             }

    # A missing or non-dict configuration means "use the defaults"; previously
    # this crashed on the unconditional parameters['path'] lookup below.
    if not isinstance(parameters, dict):
        parameters = {}

    # Override only the known hyper parameters the user actually supplied.
    # The old test was `key in param`, which is always true while iterating
    # param, so every key missing from `parameters` raised KeyError.
    param.update({key: parameters[key] for key in param if key in parameters})

    # Forward 'path' / 'trainFraction' only when given so that State's own
    # defaults apply otherwise.
    state_options = {}
    if 'path' in parameters:
        state_options['path'] = str(parameters['path'])
    if 'trainFraction' in parameters:
        state_options['trainFraction'] = float(parameters['trainFraction'])

    return State(params=param, **state_options)
51
52
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Hook called before fitting starts; nothing needs to be prepared here.

    The test sample arguments (Xtest, Stest, ytest, wtest) and nBatches are
    ignored, and the state object is handed back unchanged.
    """
    return state
58
59
def feature_importance(state):
    """
    Return the 'gain' feature importances of the trained booster as a plain list

    Requires state.bst to be set, i.e. a model has been trained or loaded.
    """
    gains = state.bst.feature_importance('gain')
    return gains.tolist()
65
66
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Run the complete training on the given sample:
     1. draw a random permutation of the rows and split it at trainFraction
     2. wrap both parts into LightGBM datasets (stored on the state)
     3. train the booster and store it in state.bst

    S, epoch and batch are not used. Always returns True.
    """
    # random train / validation split
    n_rows = X.shape[0]
    order = np.random.permutation(n_rows)
    n_train = int(n_rows * state.trainFraction)
    train_rows = order[:n_train]
    valid_rows = order[n_train:]

    state.train_set = lgb.Dataset(X[train_rows],
                                  label=y[train_rows],
                                  weight=w[train_rows])

    # The validation dataset object is created in every case, even if the
    # remaining slice is empty; it is only handed to the training below when
    # trainFraction differs from 1.
    state.validation_set = state.train_set.create_valid(
        X[valid_rows], label=y[valid_rows], weight=w[valid_rows])

    # training
    if state.trainFraction == 1:
        state.bst = lgb.train(state.params, state.train_set)
    else:
        state.bst = lgb.train(state.params, state.train_set, valid_sets=[state.validation_set])
    return True
91
92
def end_fit(state):
    """
    Finish the training and pack the model for storage

    The booster is saved under the name state.path inside a temporary
    directory and read back as raw bytes.

    Returns the list [file names, file contents, hyper parameters], where the
    first two entries are one-element lists describing the saved model file.
    """
    model_name = state.path
    with tempfile.TemporaryDirectory() as workdir:
        model_file = os.path.join(workdir, model_name)
        state.bst.save_model(model_file)
        with open(model_file, 'rb') as handle:
            content = handle.read()
    return [[model_name], [content], state.params]
107
108
def load(obj):
    """
    Rebuild a state from the object produced by end_fit

    obj[0] holds the file names, obj[1] the matching file contents and obj[2]
    the hyper parameters. All files are written into a temporary directory
    and the booster is initialised from the first one.
    """
    with tempfile.TemporaryDirectory() as path:
        names = obj[0]
        for position, file_name in enumerate(names):
            with open(f'{path}/{file_name}', 'w+b') as handle:
                handle.write(bytes(obj[1][position]))

        # the booster has to be created while the temporary files still exist
        model_file = os.path.join(path, names[0])
        booster = lgb.Booster(model_file=model_file)
    return State(bst=booster, params=obj[2])
124
125
def apply(state, X):
    """
    Evaluate the booster on X and return the predictions

    The result is converted to a float32 numpy array that satisfies the
    requested memory layout flags.
    """
    prediction = state.bst.predict(X)
    # A: aligned, W: writeable, C: C-contiguous, O: owns its data
    layout = ['A', 'W', 'C', 'O']
    return np.require(prediction, dtype=np.float32, requirements=layout)
trainFraction
train fraction for dataset splitting
Definition: lightgbm.py:26
params
LightGBM Model parameter.
Definition: lightgbm.py:22
def __init__(self, bst=None, params=None, X_valid=None, y_valid=None, path='LGBM.txt', trainFraction=0.8, num_round=100)
Definition: lightgbm.py:13