Belle II Software  release-05-01-25
tools.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 """ Tools collection
5 
6 In the tools collection all plotting tools are gathered.
7 
8 """
9 __author__ = 'swehle'
10 
11 from . import transform
12 from .settings import create_figure
13 
14 import matplotlib.pyplot as plt
15 import pandas as pd
16 import numpy as np
17 from scipy.stats import chisqprob
18 import math
19 
20 
def set_axis_label_range(ax, new_start, new_end, n_labels=5, axis=1, to_flat=None):
    """
    Set the tick labels of an axis to values from a different range.

    Tick positions are spread evenly over the current x-limits of ``ax``.
    The x-limits are used even when the y axis is labelled, which is only
    meaningful when both axes share the same range (e.g. square ``imshow``
    plots). Without ``to_flat`` the labels run linearly from ``new_start``
    to ``new_end``.

    :param ax: axis object
    :param new_start: New start value
    :param new_end: New end value
    :param n_labels: N labels
    :param axis: 1 (default) labels the x axis, any other value labels the y axis
    :param to_flat: Flat transformation object (``transform.ToFlat``) for getting
        non-linear values on the axis. If given, ``new_start``/``new_end`` are
        ignored: labels come from ``to_flat.get_x`` evaluated at evenly spaced
        points in [0, 1], with the end labels set to ``to_flat.min``/``to_flat.max``.
    """
    start, end = ax.get_xlim()
    # linspace always returns exactly n_labels points; the former
    # np.arange(..., float_step) + [end] could yield n_labels + 1 points through
    # rounding and raised ZeroDivisionError for an empty range.
    label_position = np.linspace(start, end, n_labels)

    if to_flat is None:
        # Linear interpolation between the requested end points
        new_labels = np.linspace(new_start, new_end, n_labels)
    else:
        # Non-linear 'correct' case using the CDF as reference: map evenly spaced
        # points of the flat distribution back to the original variable.
        assert isinstance(to_flat, transform.ToFlat)
        new_labels = [to_flat.get_x(x) for x in np.linspace(0, 1, n_labels)]
        new_labels[-1] = to_flat.max
        new_labels[0] = to_flat.min

    tick_labels = ["%.2f" % value for value in new_labels]
    if axis == 1:
        ax.set_xticks(label_position)
        ax.set_xticklabels(tick_labels)
    else:
        ax.set_yticks(label_position)
        ax.set_yticklabels(tick_labels)
56 
57 
def draw_flat_correlation(x, y, ax=None, draw_label=True, width=5):
    """
    This function draws a flat correlation distribution.

    Both x and y have to be equally sized. Each is transformed to a flat
    distribution (``transform.ToFlat``). The 2D histogram of the transformed
    values is drawn as the per-bin deviation from the flat expectation in units
    of its Poisson error, ``(n - n_exp) / sqrt(n_exp)``, with the colour scale
    clipped to [-5, 5].

    :param x: dist x, pandas Series (a numpy array is also accepted)
    :param y: dist y, pandas Series (a numpy array is also accepted)
    :param ax: axis object if drawn in a subplot. If None, a new figure is
        created via ``create_figure`` and a colour bar, tick labels in the
        original units and a title with the flat-hypothesis probability are added.
    :param draw_label: draw the labels of the distribution (only works with pandas Series)
    :param width: width of the plot, default 5
    :return: the image object returned by ``ax.imshow``
    """
    standalone = ax is None

    # Validate before creating a figure. The former check
    # ``isinstance(x, pd.Series or np.array)`` only ever tested pd.Series.
    assert isinstance(x, (pd.Series, np.ndarray)), 'Argument of wrong type!'
    assert isinstance(y, (pd.Series, np.ndarray)), 'Argument of wrong type!'
    x_val = x.values if isinstance(x, pd.Series) else np.asarray(x)
    y_val = y.values if isinstance(y, pd.Series) else np.asarray(y)

    if standalone:
        _, ax = create_figure(width=width, ratio=7 / 6.)

    # Flat Distribution
    tx = transform.ToFlat()
    ty = transform.ToFlat()
    tx.fit(x_val)
    ty.fit(y_val)

    # bins and expected events per cell of the n_bins x n_bins grid
    n_bins = transform.get_optimal_bin_size(min(len(x_val), len(y_val)))
    n_bins = int(math.sqrt(n_bins) * 2)
    nexp = len(x_val) / n_bins ** 2
    nerr = math.sqrt(nexp)
    counts, _, _ = np.histogram2d(tx.transform(x_val), ty.transform(y_val), bins=(n_bins, n_bins))

    # Deviation from the flat expectation in units of its error
    pulls = (counts - nexp) / nerr

    # Draw the matrix (transposed so that x runs along the horizontal axis)
    im = ax.imshow(pulls.T, interpolation='nearest', vmin=-5, vmax=5)
    if standalone:
        plt.colorbar(im, fraction=0.046, pad=0.04)
        set_axis_label_range(ax, x.min(), x.max(), to_flat=tx)
        set_axis_label_range(ax, y.min(), y.max(), axis=0, to_flat=ty)
    else:
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    if draw_label:
        # numpy arrays have no name; this mirrors a Series with name=None
        ax.set_xlabel(getattr(x, 'name', None))
        ax.set_ylabel(getattr(y, 'name', None))

    if standalone:
        # scipy.stats.chisqprob was removed in SciPy 1.0; chi2.sf is the same quantity.
        from scipy.stats import chi2 as chi2_dist

        # Overall chi2 for the flat hypothesis. The degrees of freedom are the
        # n_bins**2 cells minus the 2 * (n_bins - 1) + 1 constraints from the
        # fixed marginals and total, i.e. (n_bins - 1) ** 2.
        chi2_value = np.sum(pulls ** 2)
        ndf = n_bins * n_bins - ((n_bins - 1) + (n_bins - 1) + 1)
        proba = chi2_dist.sf(chi2_value, ndf)
        ax.set_title("Probability of flat hypothesis %.2f%%" % (proba * 100))
    return im
123 
124 
class ProfilePlot():

    """ Basic Profile plot

    Creates the profile histogram from x and y distributions.
    It plots mean(y) in bins of x.

    Attributes:
        x_axis (array) : Bin edges in x
        mean (list) : Mean of y in bin x (0 for empty bins)
        err (list) : Std of Mean y in bin x (0 for empty bins)
        label (string) : Matplotlib label for the plot
    """

    def __init__(self, x, y, x_axis=None, n_bins=None, label=None):
        """ init function
        :param x: Distribution in x (array-like)
        :param y: Distribution in y (array-like, same length as x)
        :param x_axis: binning for the x-axis (number of bins or bin edges,
            as accepted by ``np.histogram``)
        :param n_bins: (optional) n bins in x; takes precedence over ``x_axis``.
            If neither is given the binning is set automatically.
        :param label: Matplotlib label for the plot
        """
        # Positional numpy arrays so boolean masks work for lists as well
        x = np.asarray(x)
        y = np.asarray(y)

        if n_bins is not None:
            x_axis = n_bins
        elif x_axis is None:
            x_axis = transform.get_optimal_bin_size(len(x))

        #: Binning in x (bin edges)
        _, self.x_axis = np.histogram(x, x_axis)
        #: Mean of y in bin x
        self.mean = []
        #: Std of Mean y in bin x
        self.err = []
        #: Matplotlib label for the plot
        self.label = label

        # Calculating the Profile histogram. Bins follow np.histogram:
        # [low, high) with the last bin closed. The former strict inequalities
        # dropped the min/max entries and every entry lying on a bin edge.
        last_bin = len(self.x_axis) - 2
        for i_bin, (low, high) in enumerate(zip(self.x_axis[:-1], self.x_axis[1:])):
            below_high = (x <= high) if i_bin == last_bin else (x < high)
            y_in_bin = y[(x >= low) & below_high]
            n_y_in_bin = len(y_in_bin)
            if n_y_in_bin == 0:
                self.mean.append(0)
                self.err.append(0)
            else:
                self.mean.append(np.mean(y_in_bin))
                self.err.append(np.sqrt(np.var(y_in_bin) / n_y_in_bin))

    def draw(self, color='black'):
        """ Draw the profile as error bars at the bin centres with pyplot.
        :param color: matplotlib color
        """
        bin_centers = (self.x_axis[1:] + self.x_axis[:-1]) / 2.0
        plt.errorbar(bin_centers, self.mean, color=color, yerr=self.err,
                     linewidth=2, ecolor=color, label=self.label, fmt='.')
182 
183 
def draw_flat_corr_matrix(df, pdf=None, tight=False, col_numbers=False, labels=None, fontsize=18, size=12):
    """
    Draw a matrix of flat correlation plots for all pairs of columns of ``df``.

    Diagonal cells show the histogram of each column. Off-diagonal cell (i, j)
    shows ``draw_flat_correlation`` of column i (x) against column j (y).

    :param df: DataFrame of the input data
    :param pdf: optional, object providing ``savefig()`` (such as matplotlib's
        ``PdfPages``) to save the figure to; the figure is closed afterwards
    :param tight: tight layout, be careful
    :param col_numbers: switch between numbers or names for the columns
    :param labels: optional, list of latex labels (defaults to the column names)
    :param fontsize: size of the labels
    :param size: width and height of the figure (matplotlib ``figsize``)
    """
    assert isinstance(df, pd.DataFrame), 'Argument of wrong type!'

    n_vars = df.shape[1]

    if labels is None:
        labels = df.columns

    # squeeze=False keeps ``axes`` two-dimensional even for a single column
    _, axes = plt.subplots(nrows=n_vars, ncols=n_vars, figsize=(size, size), squeeze=False)
    for i in range(n_vars):
        # .iloc: positional access (.ix was removed in pandas 1.0)
        column_i = df.iloc[:, i]
        for j in range(n_vars):
            ax = axes[i, j]
            if i == j:
                plt.sca(ax)
                plt.hist(column_i.values, transform.get_optimal_bin_size(len(df)), color="gray", histtype='step')
                ax.set_yticklabels([])
                set_axis_label_range(ax, column_i.min(), column_i.max(), n_labels=3)
            else:
                draw_flat_correlation(column_i, df.iloc[:, j], ax=ax, draw_label=False)

            if i == n_vars - 1 and j != n_vars - 1:
                plt.setp(ax.get_xticklabels(), visible=False)

            if i == n_vars - 1:
                ax.xaxis.set_label_coords(0.5, -0.15)

    if tight:
        plt.tight_layout()

    # Common outer label: x labels on the bottom row, y labels on the first column
    for j in range(n_vars):
        if col_numbers:
            axes[n_vars - 1, j].set_xlabel("%d" % j)
        else:
            axes[n_vars - 1, j].set_xlabel(labels[j], fontsize=fontsize)
    for i in range(n_vars):
        if col_numbers:
            axes[i, 0].set_ylabel("%d" % i)
        else:
            axes[i, 0].set_ylabel(labels[i], fontsize=fontsize)

    if pdf is not None:
        pdf.savefig()
        plt.close()
240 
241 
def draw_fancy_correlation_matrix(df, pdf=None, tight=False, col_numbers=False, labels=None, fontsize=18, size=12):
    """
    Draws a colored correlation matrix with a profile plot overlay.

    Diagonal cells show the histogram of each column. Off-diagonal cell (i, j)
    has its background coloured by the correlation coefficient of columns i and
    j (``jet`` colour map on [-1, 1]). It shows a white profile plot of column j
    in 10 bins of column i and the coefficient printed in the centre.

    :param df: DataFrame of the input data
    :param pdf: optional, object providing ``savefig()`` (such as matplotlib's
        ``PdfPages``) to save the figure to; the figure is closed afterwards
    :param tight: tight layout, be careful
    :param col_numbers: switch between numbers or names for the columns
    :param labels: optional, list of latex labels (defaults to the column names)
    :param fontsize: size of the labels
    :param size: width and height of the figure (matplotlib ``figsize``)
    """
    # Import the submodule explicitly instead of relying on it being loaded already
    import matplotlib.colors

    assert isinstance(df, pd.DataFrame), 'Argument of wrong type!'

    n_vars = df.shape[1]

    if labels is None:
        labels = df.columns

    corr = df.corr().values
    norm = matplotlib.colors.Normalize(vmin=-1, vmax=1)
    cma = plt.cm.ScalarMappable(norm=norm, cmap=plt.cm.jet)

    # squeeze=False keeps ``axes`` two-dimensional even for a single column
    _, axes = plt.subplots(nrows=n_vars, ncols=n_vars, figsize=(size, size), squeeze=False)
    for i in range(n_vars):
        # .iloc: positional access (.ix was removed in pandas 1.0)
        column_i = df.iloc[:, i]
        for j in range(n_vars):
            ax = axes[i, j]
            plt.sca(ax)
            if i == j:
                plt.hist(column_i.values, transform.get_optimal_bin_size(len(df)), color="gray", histtype='step')
                ax.set_yticklabels([])
                set_axis_label_range(ax, column_i.min(), column_i.max(), n_labels=3)
            else:
                h = ProfilePlot(column_i.values, df.iloc[:, j].values, label='data', n_bins=10)
                h.draw(color="white")

                x_middle = (plt.xlim()[1] + plt.xlim()[0]) / 2.
                y_middle = (plt.ylim()[1] + plt.ylim()[0]) / 2.
                ax.text(x_middle, y_middle, "$%.3f$" % corr[i, j], fontsize=24, va='center', ha='center')

                ax.patch.set_facecolor(cma.to_rgba(corr[i, j]))

                ax.set_yticklabels([])
                ax.set_xticklabels([])

            if i == n_vars - 1 and j != n_vars - 1:
                plt.setp(ax.get_xticklabels(), visible=False)

            if i == n_vars - 1:
                ax.xaxis.set_label_coords(0.5, -0.15)

    if tight:
        plt.tight_layout()

    # Common outer label: x labels on the bottom row, y labels on the first column
    for j in range(n_vars):
        if col_numbers:
            axes[n_vars - 1, j].set_xlabel("%d" % j)
        else:
            axes[n_vars - 1, j].set_xlabel(labels[j], fontsize=fontsize)
    for i in range(n_vars):
        if col_numbers:
            axes[i, 0].set_ylabel("%d" % i)
        else:
            axes[i, 0].set_ylabel(labels[i], fontsize=fontsize)

    if pdf is not None:
        pdf.savefig()
        plt.close()
alignment.fancystuff.tools.ProfilePlot.mean
mean
Mean of y in bin x.
Definition: tools.py:156
alignment.fancystuff.tools.ProfilePlot.x_axis
x_axis
Binning in x.
Definition: tools.py:153
alignment.fancystuff.tools.ProfilePlot.err
err
Std of Mean y in bin x.
Definition: tools.py:159
alignment.fancystuff.tools.ProfilePlot
Definition: tools.py:125
alignment.fancystuff.tools.ProfilePlot.__init__
def __init__(self, x, y, x_axis=None, n_bins=None, label=None)
Definition: tools.py:139
alignment.fancystuff.tools.ProfilePlot.draw
def draw(self, color='black')
Definition: tools.py:175
alignment.fancystuff.tools.ProfilePlot.label
label
Matplotlib label for the plot.
Definition: tools.py:162