Belle II Software  release-06-02-00
tensorflow.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import numpy as np
13 import sys
14 import os
15 import tempfile
16 
17 
class State(object):
    """
    Tensorflow state

    Container for the tensors/ops of a model and the session running them.
    Every name listed in ``collection_keys`` is stored as an attribute of the
    same name, so it can be round-tripped through tensorflow collections with
    add_to_collection / get_from_collection.
    """

    def __init__(self, x=None, y=None, activation=None, cost=None, optimizer=None, session=None, **kwargs):
        """ Constructor of the state object

        Any additional keyword argument is stored as an attribute and its name is
        appended to ``collection_keys`` so it is saved/restored as well.
        """

        #: input placeholder (fed with the feature matrix)
        self.x = x
        #: target placeholder
        self.y = y
        #: activation (model output)
        self.activation = activation
        #: cost function
        self.cost = cost
        #: optimizer used to minimize cost function
        self.optimizer = optimizer
        #: tensorflow session
        self.session = session
        #: array to save keys for collection
        # attribute names must match these keys: the collection helpers use getattr/setattr
        self.collection_keys = ['x', 'y', 'activation', 'cost', 'optimizer']

        # other possible things to save into a tensorflow collection
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)

    def add_to_collection(self):
        """ Add the stored members to the current tensorflow collection

        Returns the list of collection keys that were added.
        """
        try:
            import tensorflow as tf
        except ImportError:
            print("Please install tensorflow: pip3 install tensorflow")
            sys.exit(1)

        for key in self.collection_keys:
            tf.add_to_collection(key, getattr(self, key))

        return self.collection_keys

    def get_from_collection(self, collection_keys=None):
        """ Get members from the current tensorflow collection

        If ``collection_keys`` is given, it replaces the stored key list before
        loading. For every key, the first element of the collection of that name
        is stored as an attribute.
        """
        try:
            import tensorflow as tf
        except ImportError:
            print("Please install tensorflow: pip3 install tensorflow")
            sys.exit(1)

        if collection_keys is not None:
            self.collection_keys = collection_keys

        for key in self.collection_keys:
            setattr(self, key, tf.get_collection(key)[0])
71 
72 
def feature_importance(state):
    """
    Feature importances are not computed for this model, so an empty list is
    always returned regardless of the state.
    """
    return list()
78 
79 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default tensorflow model

    Builds a single-layer logistic regression (sigmoid(x W + b)) trained by
    plain gradient descent on an event-weighted binary cross-entropy. NaN
    features are replaced by 0 before they enter the model.

    Only ``number_of_features`` is used; the other arguments belong to the
    framework interface and are ignored here.

    Returns a State holding the placeholders x, y and w, the activation, the
    cost, the optimizer op and an initialised session.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("Please install tensorflow: pip3 install tensorflow")
        sys.exit(1)

    # NOTE: this uses the TensorFlow 1.x graph API (placeholder, Session, ...).
    x = tf.placeholder("float", [None, number_of_features])
    y = tf.placeholder("float", [None, 1])
    w = tf.placeholder("float", [None, 1])
    W = tf.Variable(tf.zeros([number_of_features, 1]))
    b = tf.Variable(tf.zeros([1]))

    # Replace NaN inputs by zero. tf.select was removed in TensorFlow 1.0; the
    # three-argument tf.where is its element-wise replacement.
    x_clean = tf.where(tf.is_nan(x), tf.zeros_like(x), x)
    activation = tf.nn.sigmoid(tf.matmul(x_clean, W) + b)

    # epsilon keeps the logarithms finite when the activation saturates at 0 or 1
    epsilon = 1e-5
    cost = -tf.reduce_sum(y * w * tf.log(activation + epsilon) +
                          (1 - y) * w * tf.log(1 - activation + epsilon)) / tf.reduce_sum(w)

    learning_rate = 0.001
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, global_step=global_step)

    init = tf.global_variables_initializer()

    config = tf.ConfigProto()
    # do not reserve all GPU memory up front
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    session.run(init)

    state = State(x, y, activation, cost, optimizer, session)
    # the weight placeholder is needed by partial_fit
    state.w = w
    return state
116 
117 
def load(obj):
    """
    Load Tensorflow estimator into state

    ``obj`` is the list produced by end_fit:
    obj[0] exported meta graph, obj[1] checkpoint basename, obj[2] / obj[3] raw
    bytes of the checkpoint data and index files. If present, obj[4] holds the
    collection keys to restore, and every further entry is attached to the state
    as ``extra_0``, ``extra_1``, ...
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("Please install tensorflow: pip3 install tensorflow")
        sys.exit(1)

    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    saver = tf.train.import_meta_graph(obj[0])

    checkpoint_name = obj[1]
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint = os.path.join(tmp_dir, checkpoint_name)
        # the files are closed (and therefore flushed) before restoring from them
        with open(checkpoint + '.data-00000-of-00001', 'w+b') as data_file, open(checkpoint + '.index', 'w+b') as index_file:
            data_file.write(bytes(obj[2]))
            index_file.write(bytes(obj[3]))
        tf.train.update_checkpoint_state(tmp_dir, checkpoint_name)
        saver.restore(session, checkpoint)

    state = State(session=session)
    if len(obj) <= 4:
        # no explicit key list stored: use the State defaults
        state.get_from_collection()
    else:
        state.get_from_collection(obj[4])
        for index, extra in enumerate(obj[5:]):
            setattr(state, 'extra_{}'.format(index), extra)

    return state
149 
150 
def apply(state, X):
    """
    Apply estimator to passed data.

    Runs the stored activation with X fed into the input placeholder and returns
    the flattened result as a float32 array that is aligned, writeable,
    C-contiguous and owns its data.
    """
    prediction = state.session.run(state.activation, feed_dict={state.x: X})
    flat = prediction.flatten()
    return np.require(flat, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
157 
158 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Start of training: nothing needs preparing, so the validation data is ignored
    and the very same state object is handed back.
    """
    unchanged = state
    return unchanged
164 
165 
def partial_fit(state, X, S, y, w, epoch):
    """
    Pass received data to tensorflow session

    Runs one optimisation step on (X, y, w); the spectators S are not used.
    The cost is evaluated (after the update) only on epochs where it is printed,
    i.e. every 1000th epoch, instead of costing an extra session run on every call.

    Returns False at epoch 100000 to stop training, True otherwise.
    """
    feed_dict = {state.x: X, state.y: y, state.w: w}
    state.session.run(state.optimizer, feed_dict=feed_dict)
    if epoch % 1000 == 0:
        avg_cost = state.session.run(state.cost, feed_dict=feed_dict)
        print("Epoch:", '%04d' % (epoch), "cost=", "{:.9f}".format(avg_cost))
    if epoch == 100000:
        return False
    return True
177 
178 
def end_fit(state):
    """
    Store tensorflow session in a graph

    Registers the state's members in the tensorflow collections, saves the
    session variables to a temporary checkpoint and reads its data and index
    files back as bytes.

    Returns [meta_graph, checkpoint basename, data file bytes, index file bytes,
    collection keys], the layout that load() consumes.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("Please install tensorflow: pip3 install tensorflow")
        sys.exit(1)

    collection_keys = state.add_to_collection()
    saver = tf.train.Saver()
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint = saver.save(state.session, os.path.join(tmp_dir, 'mymodel'))
        data_path = checkpoint + '.data-00000-of-00001'
        index_path = checkpoint + '.index'
        with open(data_path, 'rb') as data_file, open(index_path, 'rb') as index_file:
            data_bytes = data_file.read()
            index_bytes = index_file.read()
    meta_graph = saver.export_meta_graph()
    return [meta_graph, os.path.basename(checkpoint), data_bytes, index_bytes, collection_keys]
def get_from_collection(self, collection_keys=None)
Definition: tensorflow.py:58
collection_keys
array to save keys for collection
Definition: tensorflow.py:38
optimizer
optimizer used to minimize cost function
Definition: tensorflow.py:34
def __init__(self, x=None, y=None, activation=None, cost=None, optimizer=None, session=None, **kwargs)
Definition: tensorflow.py:23