Belle II Software  release-05-02-19
tensorflow.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 # Thomas Keck 2016
5 
6 import numpy as np
7 import sys
8 import os
9 import tempfile
10 
11 
class State(object):
    """
    Holder for the tensorflow graph handles and session of one model.

    Every member whose name is listed in ``collection_keys`` can be pushed
    into, and later recovered from, the tensorflow graph collection that has
    the same name.
    """

    def __init__(self, x=None, y=None, activation=None, cost=None, optimizer=None, session=None, **kwargs):
        """
        Remember the given graph handles and session.

        Each additional keyword argument is stored as an attribute of the same
        name, and that name is appended to ``collection_keys``.
        """
        #: feature matrix placeholder
        self.x = x
        #: target placeholder
        self.y = y
        #: activation function
        self.activation = activation
        #: cost function
        self.cost = cost
        #: optimizer used to minimize cost function
        self.optimizer = optimizer
        #: tensorflow session
        self.session = session
        #: array to save keys for collection
        self.collection_keys = ['x', 'y', 'activation', 'cost', 'optimizer']

        # extra named members become part of the collection as well
        for name, member in kwargs.items():
            self.collection_keys.append(name)
            setattr(self, name, member)

    def add_to_collection(self):
        """
        Register each member named in ``collection_keys`` in the tensorflow
        collection of the same name.

        :return: ``self.collection_keys`` (the list object itself)
        """
        try:
            import tensorflow as tf
        except ImportError:
            print("Please install tensorflow: pip3 install tensorflow")
            sys.exit(1)

        registered = self.collection_keys
        for name in registered:
            tf.add_to_collection(name, getattr(self, name))
        return registered

    def get_from_collection(self, collection_keys=None):
        """
        Set each member from the first entry of the tensorflow collection
        carrying its name.

        :param collection_keys: if given, replaces ``self.collection_keys``
                                before the lookup
        """
        try:
            import tensorflow as tf
        except ImportError:
            print("Please install tensorflow: pip3 install tensorflow")
            sys.exit(1)

        if collection_keys is not None:
            self.collection_keys = collection_keys

        for name in self.collection_keys:
            entries = tf.get_collection(name)
            setattr(self, name, entries[0])
65 
66 
def feature_importance(state):
    """
    Report feature importances for the model held in ``state``.

    This model does not compute an importance ranking, so the result is
    always an empty list.
    """
    importances = list()
    return importances
72 
73 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build the default tensorflow model, a weighted logistic regression.

    :param number_of_features: number of columns of the feature matrix
    :param number_of_spectators: not used by this model
    :param number_of_events: not used by this model
    :param training_fraction: not used by this model
    :param parameters: not used by this model
    :return: State holding the placeholders, activation, cost, optimizer and an
             initialized session; the event-weight placeholder is attached as
             ``state.w``
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("Please install tensorflow: pip3 install tensorflow")
        sys.exit(1)

    x = tf.placeholder("float", [None, number_of_features])
    y = tf.placeholder("float", [None, 1])
    w = tf.placeholder("float", [None, 1])
    W = tf.Variable(tf.zeros([number_of_features, 1]))
    b = tf.Variable(tf.zeros([1]))

    # Replace NaN inputs by zero. tf.select was removed in TensorFlow 1.0;
    # tf.where(condition, x, y) is its element-wise replacement.
    x_clean = tf.where(tf.is_nan(x), tf.zeros_like(x), x)
    activation = tf.nn.sigmoid(tf.matmul(x_clean, W) + b)

    # epsilon keeps the logarithms finite when the activation saturates at 0 or 1
    epsilon = 1e-5
    cost = -tf.reduce_sum(y * w * tf.log(activation + epsilon) +
                          (1 - y) * w * tf.log(1 - activation + epsilon)) / tf.reduce_sum(w)

    learning_rate = 0.001
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, global_step=global_step)

    init = tf.global_variables_initializer()

    # allocate GPU memory on demand instead of reserving it all up front
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    session.run(init)

    state = State(x, y, activation, cost, optimizer, session)
    state.w = w
    return state
110 
111 
def load(obj):
    """
    Rebuild a State from the serialized form produced by end_fit.

    :param obj: list laid out as ``[meta_graph, checkpoint_name, data_bytes,
                index_bytes, keys, *extras]``; ``keys`` and the extras are
                optional
    :return: State with a restored session and the collection members
             attached; entries after ``keys`` become ``extra_0``, ``extra_1``, ...
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("Please install tensorflow: pip3 install tensorflow")
        sys.exit(1)

    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    saver = tf.train.import_meta_graph(obj[0])

    checkpoint_name = obj[1]
    with tempfile.TemporaryDirectory() as path:
        # put the checkpoint files back on disk so the saver can read them;
        # each file is closed (flushed) before the restore below
        checkpoint_files = ((checkpoint_name + '.data-00000-of-00001', obj[2]),
                            (checkpoint_name + '.index', obj[3]))
        for filename, payload in checkpoint_files:
            with open(os.path.join(path, filename), 'w+b') as handle:
                handle.write(bytes(payload))
        tf.train.update_checkpoint_state(path, checkpoint_name)
        saver.restore(session, os.path.join(path, checkpoint_name))

    state = State(session=session)
    if len(obj) <= 4:
        state.get_from_collection()
        return state

    state.get_from_collection(obj[4])
    for index, extra in enumerate(obj[5:]):
        setattr(state, 'extra_{}'.format(index), extra)
    return state
143 
144 
def apply(state, X):
    """
    Evaluate the model stored in ``state`` on the feature matrix ``X``.

    :return: flattened outputs as an aligned, writeable, C-contiguous float32
             array that owns its memory
    """
    outputs = state.session.run(state.activation, feed_dict={state.x: X})
    flat = outputs.flatten()
    return np.require(flat, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
151 
152 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Prepare for training.

    This model needs no setup: the validation arguments are ignored and the
    given state is handed back unchanged.
    """
    prepared = state
    return prepared
158 
159 
def partial_fit(state, X, S, y, w, epoch):
    """
    Run one optimizer step of the tensorflow session on the passed batch.

    The cost (after the update) is evaluated and printed only every 1000
    epochs. Evaluating it on the other epochs would be an extra forward pass
    whose result is thrown away.

    :param state: State holding session, placeholders, optimizer and cost
    :param X: feature matrix fed to ``state.x``
    :param S: spectators (unused)
    :param y: targets fed to ``state.y``
    :param w: event weights fed to ``state.w``
    :param epoch: current epoch number
    :return: False when ``epoch`` equals 100000 (stop training), True otherwise
    """
    feed_dict = {state.x: X, state.y: y, state.w: w}
    state.session.run(state.optimizer, feed_dict=feed_dict)
    if epoch % 1000 == 0:
        avg_cost = state.session.run(state.cost, feed_dict=feed_dict)
        print("Epoch:", '%04d' % (epoch), "cost=", "{:.9f}".format(avg_cost))
    return epoch != 100000
171 
172 
def end_fit(state):
    """
    Serialize the trained tensorflow model.

    The collection members are registered in the graph. The variables are
    checkpointed into a temporary directory, and the checkpoint files are read
    back into memory.

    :param state: State holding the trained session
    :return: ``[meta_graph, checkpoint_name, data_bytes, index_bytes, keys]``,
             the layout that load consumes
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("Please install tensorflow: pip3 install tensorflow")
        sys.exit(1)

    keys = state.add_to_collection()
    saver = tf.train.Saver()
    with tempfile.TemporaryDirectory() as path:
        filename = saver.save(state.session, os.path.join(path, 'mymodel'))
        with open(filename + '.data-00000-of-00001', 'rb') as file1, open(filename + '.index', 'rb') as file2:
            data1 = file1.read()
            data2 = file2.read()
    meta_graph = saver.export_meta_graph()
    # NOTE: state.session is not closed here; unbinding the local name would
    # not release it either, so any cleanup is left to the caller.
    return [meta_graph, os.path.basename(filename), data1, data2, keys]
basf2_mva_python_interface.tensorflow.State.session
session
tensorflow session
Definition: tensorflow.py:30
basf2_mva_python_interface.tensorflow.State.activation
activation
activation function
Definition: tensorflow.py:24
basf2_mva_python_interface.tensorflow.State.optimizer
optimizer
optimizer used to minimize cost function
Definition: tensorflow.py:28
basf2_mva_python_interface.tensorflow.State.collection_keys
collection_keys
array to save keys for collection
Definition: tensorflow.py:32
basf2_mva_python_interface.tensorflow.State.x
x
feature matrix placeholder
Definition: tensorflow.py:20
basf2_mva_python_interface.tensorflow.State.add_to_collection
def add_to_collection(self)
Definition: tensorflow.py:39
basf2_mva_python_interface.tensorflow.State.__init__
def __init__(self, x=None, y=None, activation=None, cost=None, optimizer=None, session=None, **kwargs)
Definition: tensorflow.py:17
basf2_mva_python_interface.tensorflow.State
Definition: tensorflow.py:12
basf2_mva_python_interface.tensorflow.State.y
y
target placeholder
Definition: tensorflow.py:22
basf2_mva_python_interface.tensorflow.State.get_from_collection
def get_from_collection(self, collection_keys=None)
Definition: tensorflow.py:52
basf2_mva_python_interface.tensorflow.State.cost
cost
cost function
Definition: tensorflow.py:26