Belle II Software  release-05-02-19
matplotting.py
1 """Tools to plot ROOT object into matplotlib"""
2 
3 import ROOT
4 
5 import re
6 import functools
7 import numpy as np
8 
9 np.seterr(invalid='ignore')
10 
11 import sys
12 
13 flt_max = sys.float_info.max
14 
15 flt_min = sys.float_info.min
16 
17 import collections
18 from .plot import ValidationPlot
19 
20 import logging
21 
22 
def get_logger():
    """Return the logger belonging to this module (named after ``__name__``)."""
    module_name = __name__
    return logging.getLogger(module_name)
26 
27 
28 try:
29  import matplotlib
30  # Switch to noninteractive backend
31  matplotlib.use('Agg')
32 
33  import matplotlib.pyplot as plt
34  import matplotlib.transforms as transforms
35 except ImportError:
36  raise ImportError("matplotlib is not installed in your basf2 environment. "
37  "You may install it with 'pip install matplotlib'")
38 
39 
class defaults:
    """Fallback values used when a plotting option is passed as None."""

    # Whether the by-side legend is drawn (consulted by put_legend for legend=None).
    legend = True

    # Whether statistics labels are generated (consulted by create_label for label=None).
    label = True

    # Matplotlib style sheet applied by the use_style decorator for style=None.
    style = "bmh"
48 
49 
# ROOT classes that the plotting functions of this module can render.
plotable_classes = (
    ROOT.TH1,
    ROOT.TH2,
    ROOT.THStack,
    ROOT.TGraph,
    ROOT.TGraphErrors,
    ROOT.TMultiGraph,
)


def is_plotable(tobject):
    """Tell whether an object can be plotted with matplotlib using this module.

    Parameters
    ----------
    tobject : ROOT.TObject
        Object to check.

    Returns
    -------
    bool
        True if *tobject* is an instance of one of ``plotable_classes``.
    """
    plotable = isinstance(tobject, plotable_classes)
    return plotable
63 
64 
def plot(tobject, **kwd):
    """Plot the given plotable TObject.

    Dispatches to the plotting function that matches the type of *tobject*.

    Parameters
    ----------
    tobject : ROOT.TObject
        Plotable TObject: TH2, THStack, TMultiGraph, TGraph, TProfile or TH1.
    legend : bool, optional
        Create a by-side legend containing statistical information.
    style : list(str), optional
        List of matplotlib styles to be used for the plotting.
    label : bool, optional
        Whether to label the plotted objects with their summary statistics.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot

    Raises
    ------
    ValueError
        If *tobject* is of an unsupported type.
    """
    # The more specific types are tested before the generic ROOT.TH1 catch-all,
    # since e.g. TH2 and TProfile objects are also TH1 instances in ROOT.
    if isinstance(tobject, ROOT.TH2):
        return plot_th2(tobject, **kwd)

    elif isinstance(tobject, ROOT.THStack):
        return plot_thstack(tobject, **kwd)

    elif isinstance(tobject, ROOT.TMultiGraph):
        return plot_tmultigraph(tobject, **kwd)

    elif isinstance(tobject, ROOT.TGraph):
        return plot_tgraph(tobject, **kwd)

    elif isinstance(tobject, (ROOT.TProfile, ROOT.TH1)):
        return plot_th1(tobject, **kwd)

    raise ValueError("Plotting to matplotlib only supported for TH1, TProfile, TH2, "
                     "THStack, TGraph, and TMultiGraph.")
101 
102 
def use_style(plot_function):
    """Decorator that configures matplotlib's global style before each plot call.

    The ``style`` keyword of the wrapped call selects the style sheet(s); if it is
    missing or None, ``defaults.style`` is used. A fixed set of rcParams (patch line
    width, DejaVu Sans math and sans-serif fonts, cmtt10 monospace font) is then
    applied on top of the chosen style.
    """
    rc_overrides = (
        ("patch.linewidth", 2),
        ('mathtext.fontset', 'custom'),
        ('mathtext.rm', 'DejaVu Sans'),
        ('mathtext.it', 'DejaVu Sans:italic'),
        ('mathtext.cal', 'DejaVu Sans:italic'),
        ('mathtext.bf', 'DejaVu Sans:bold'),
        ('font.sans-serif', ['DejaVu Sans']),
        ('font.monospace', ['cmtt10']),
    )

    @functools.wraps(plot_function)
    def plot_function_with_style(*args, **kwds):
        requested_style = kwds.get("style", None)
        matplotlib.style.use(defaults.style if requested_style is None else requested_style)

        for rc_key, rc_value in rc_overrides:
            matplotlib.rcParams[rc_key] = rc_value

        return plot_function(*args, **kwds)

    return plot_function_with_style
126 
127 
@use_style
def plot_th2(th2,
             label=None,
             legend=None,
             style=None,
             **kwd):
    """Plot a two dimensional histogram into a new figure.

    Parameters
    ----------
    th2 : ROOT.TH2
        Histogram to plot.
    label : bool, optional
        Whether to label the histogram with its summary statistics.
    legend : bool, optional
        Whether to draw the by-side legend.
    style : list(str), optional
        Matplotlib style(s); consumed by the ``use_style`` decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.
    """
    fig, ax = create_figure(title=create_title(th2))
    plot_th2_data_into(ax, th2, label=create_label(th2, label))
    put_legend(fig, ax, legend=legend)
    return fig
144 
145 
@use_style
def plot_thstack(thstack,
                 label=None,
                 legend=None,
                 style=None,
                 **kwd):
    """Plot the one dimensional histograms of a stack overlaid (not stacked).

    Parameters
    ----------
    thstack : ROOT.THStack
        Stack whose histograms should be plotted.
    label : bool, optional
        Whether to label each histogram with its summary statistics.
    legend : bool, optional
        Whether to draw the by-side legend.
    style : list(str), optional
        Matplotlib style(s); consumed by the ``use_style`` decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.

    Raises
    ------
    ValueError
        If the stack is empty, holds only three dimensional histograms,
        or mixes histograms of different dimensions.
    NotImplementedError
        If the stack holds only two dimensional histograms.
    """
    hists = thstack.GetHists()
    ths = list(hists) if hists else []

    # Validate before creating the figure so that no orphaned figure is left behind.
    if not ths:
        raise ValueError("Cannot plot an empty stack of histograms")

    if all(isinstance(th, ROOT.TH3) for th in ths):
        raise ValueError("Cannot plot a stack of three dimensional histograms")

    if all(isinstance(th, ROOT.TH2) and not isinstance(th, ROOT.TH3) for th in ths):
        raise NotImplementedError("Cannot plot a stack of two dimensional histograms")

    if not all(isinstance(th, ROOT.TH1) and not isinstance(th, (ROOT.TH3, ROOT.TH2)) for th in ths):
        raise ValueError("Stack of histograms with mismatching dimensions")

    title = create_title(thstack)
    fig, ax = create_figure(title=title)

    # Currently the histograms are only overlaid, in the order returned by GetHists().
    for th1 in ths:
        th1_label = create_label(th1, label=label)
        plot_th1_data_into(ax, th1, label=th1_label)

    # Fixing that the limits sometimes clip in the y direction
    max_bin_content = thstack.GetMaximum("nostack")
    ax.set_ylim(0, 1.02 * max_bin_content)

    put_legend(fig, ax, legend=legend)
    return fig
182 
183 
@use_style
def plot_tmultigraph(tmultigraph,
                     label=None,
                     legend=None,
                     style=None,
                     **kwd):
    """Plot all graphs of a TMultiGraph overlaid into a new figure.

    Parameters
    ----------
    tmultigraph : ROOT.TMultiGraph
        Collection of graphs to plot.
    label : bool, optional
        Whether to label each graph with its summary statistics.
    legend : bool, optional
        Whether to draw the by-side legend.
    style : list(str), optional
        Matplotlib style(s); consumed by the ``use_style`` decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.
    """
    fig, ax = create_figure(title=create_title(tmultigraph))

    for member_graph in tmultigraph.GetListOfGraphs():
        member_label = create_label(member_graph, label=label)
        plot_tgraph_data_into(ax, member_graph, label=member_label)

    # Merge the y range found by matplotlib with the minimum/maximum reported by ROOT.
    matplot_y_range = ax.get_ylim()
    root_y_range = (tmultigraph.GetMinimum(), tmultigraph.GetMaximum())
    ax.set_ylim(*common_bounds(matplot_y_range, root_y_range))

    put_legend(fig, ax, legend=legend)
    return fig
208 
209 
@use_style
def plot_th1(th1,
             label=None,
             legend=None,
             style=None,
             **kwd):
    """Plot a one dimensional histogram into a new figure, including its fit function if any.

    A histogram whose sum of weights is zero is skipped (logged at info level) and
    the returned figure only carries the title.

    Parameters
    ----------
    th1 : ROOT.TH1
        Histogram (or profile) to plot.
    label : bool, optional
        Whether to label the histogram and its fit with summary statistics.
    legend : bool, optional
        Whether to draw the by-side legend.
    style : list(str), optional
        Matplotlib style(s); consumed by the ``use_style`` decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.
    """
    fig, ax = create_figure(title=create_title(th1))
    th1_label = create_label(th1, label)

    if th1.GetSumOfWeights() == 0:
        get_logger().info("Skipping empty histogram %s", th1.GetName())
        return fig

    plot_th1_data_into(ax, th1, label=th1_label)

    fit_handles = None
    tf1 = get_fit(th1)
    if tf1:
        fit_handles = plot_tf1_data_into(ax, tf1, label=create_label(tf1, label=label))

    put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title="Fit")
    return fig
238 
239 
@use_style
def plot_tgraph(tgraph,
                label=None,
                legend=None,
                style=None,
                **kwd):
    """Plots graph including the fit function if present

    Parameters
    ----------
    tgraph : ROOT.TGraph
        Graph to plot.
    label : bool, optional
        Whether to label the graph and its fit with summary statistics.
    legend : bool, optional
        Whether to draw the by-side legend.
    style : list(str), optional
        Matplotlib style(s); consumed by the ``use_style`` decorator.
    **kwd
        Further keyword arguments are accepted and ignored, like in the other
        ``plot_*`` functions, so that ``plot(tgraph, **kwd)`` works uniformly.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.
    """
    title = create_title(tgraph)
    fig, ax = create_figure(title=title)

    tgraph_label = create_label(tgraph, label=label)
    plot_tgraph_data_into(ax, tgraph, label=tgraph_label)

    tf1 = get_fit(tgraph)
    if tf1:
        tf1_label = create_label(tf1, label=label)
        fit_handles = plot_tf1_data_into(ax, tf1, label=tf1_label)
    else:
        fit_handles = None

    put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title="Fit")
    return fig
261 
262 
def create_figure(title=None):
    """Create a new pyplot figure holding a single axes.

    Parameters
    ----------
    title : str, optional
        Axes title, placed slightly above the axes. Skipped when empty or None.

    Returns
    -------
    (matplotlib.figure.Figure, matplotlib.axes.Axes)
        A prepared figure and axes into which can be plotted.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111)
    if title:
        axes.set_title(title, y=1.04)
    return figure, axes
278 
279 
def create_title(tplotable):
    """Return the title of a plotable ROOT object converted to matplotlib latex."""
    root_title = tplotable.GetTitle()
    return reformat_root_latex_to_matplotlib_latex(root_title)
283 
284 
def create_label(th_or_tgraph, label=None):
    """Create a label from the plotable object incorporating available summary statistics.

    Parameters
    ----------
    th_or_tgraph : ROOT.TH1, ROOT.TGraph or ROOT.TF1
        Object to summarise. Histograms and graphs contribute their summary
        statistics, fit functions their non-fixed parameters.
    label : bool, optional
        Whether to create a label at all; None means ``defaults.label``.
        Any truthy value enables the statistics label (a string passed here
        is not used as the label text).

    Returns
    -------
    str or None
        The rendered label, or None if labelling is disabled.

    Raises
    ------
    ValueError
        If *th_or_tgraph* is of an unsupported type.
    """
    if label is None:
        label = defaults.label

    if not label:
        return None

    if isinstance(th_or_tgraph, ROOT.TH1):
        stats = get_stats_from_th(th_or_tgraph)
    elif isinstance(th_or_tgraph, ROOT.TGraph):
        stats = get_stats_from_tgraph(th_or_tgraph)
    elif isinstance(th_or_tgraph, ROOT.TF1):
        stats = get_fit_parameters(th_or_tgraph)
    else:
        raise ValueError("Can only create a label from a ROOT.TH1, ROOT.TGraph or ROOT.TF1")

    return compose_stats_label("", stats)
311 
312 
def put_legend(fig,
               ax,
               legend=None,
               top_handles=None,
               top_title=None,
               bottom_title=None):
    """Put the legend(s) of the plot to the right of the axes.

    If a legend is requested (None falls back to ``defaults.legend``), the figure is
    widened by one third and the axes are shrunk to 75 % of their width to make room.
    When ``top_handles`` are given, they get a separate legend titled ``top_title``
    anchored at the top; all remaining handles go into a legend titled
    ``bottom_title`` anchored at the bottom.
    """
    if legend is None:
        legend = defaults.legend

    if not legend:
        return

    # Expanding figure by 33 %
    fig.set_figwidth(4.0 / 3.0 * fig.get_figwidth())

    # Shrink current axis by 25 %
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])

    if not top_handles:
        put_legend_outside(ax,
                           exclude_handles=top_handles,
                           title=bottom_title)
        return

    # The bottom legend is force-added as an artist so that the following
    # legend call for the top handles does not replace it.
    put_legend_outside(ax,
                       exclude_handles=top_handles,
                       force_add_legend=True,
                       title=bottom_title)
    put_legend_outside(ax,
                       select_handles=top_handles,
                       bottom=False,
                       title=top_title)
350 
351 
def put_legend_outside(ax,
                       title=None,
                       bottom=True,
                       select_handles=None,
                       exclude_handles=None,
                       force_add_legend=False,
                       ):
    """Put a legend right beside the axes space.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose handles should be shown.
    title : str, optional
        Title of the legend.
    bottom : bool, optional
        If True, anchor the legend's lower left corner just right of the axes'
        bottom edge; otherwise anchor its upper left corner just right of the top edge.
    select_handles : list, optional
        Handles to show. Empty or None means all labelled handles of the axes.
    exclude_handles : list, optional
        Handles to remove from the selection.
    force_add_legend : bool, optional
        Additionally register the legend as an artist of the axes, so that a
        subsequent ``ax.legend`` call does not replace it.

    Returns
    -------
    matplotlib.legend.Legend
        The created legend.
    """
    # None defaults instead of shared mutable [] defaults; both are falsy, so the
    # semantics for callers are unchanged.
    if not select_handles:
        select_handles, _ = ax.get_legend_handles_labels()

    if exclude_handles:
        select_handles = [handle for handle in select_handles if handle not in exclude_handles]

    if bottom:
        bbox_anchor, location = (1.02, 0), 3  # matplotlib loc 3 == 'lower left'
    else:
        bbox_anchor, location = (1.02, 1), 2  # matplotlib loc 2 == 'upper left'

    legend_artist = ax.legend(handles=select_handles,
                              bbox_to_anchor=bbox_anchor,
                              borderaxespad=0.,
                              loc=location,
                              prop={"family": "monospace"},
                              title=title)

    if force_add_legend:
        ax.add_artist(legend_artist)

    return legend_artist
403 
404 
def get_fit(th1_or_tgraph):
    """Return the first TF1 in the object's list of functions, or None if there is none."""
    attached_fits = (candidate
                     for candidate in th1_or_tgraph.GetListOfFunctions()
                     if isinstance(candidate, ROOT.TF1))
    return next(attached_fits, None)
411 
def get_fit_parameters(tf1):
    """Retrieve the fitted parameters explicitly excluding fixed parameters.

    Fixed parameters are usually additional stats entries that are already shown in the main
    legend for the plot.

    Returns
    -------
    collections.OrderedDict
        ``"formula"`` mapped to the function title, followed by parameter name to
        value for every parameter that is not fixed.
    """
    parameters = collections.OrderedDict()
    parameters["formula"] = tf1.GetTitle()

    for i_parameter in range(tf1.GetNpar()):
        lower_bound = ROOT.Double()
        upper_bound = ROOT.Double()
        tf1.GetParLimits(i_parameter, lower_bound, upper_bound)

        # Identical, non-zero limits mark a fixed parameter, i.e. an additional stats entry
        is_fixed = lower_bound == upper_bound and lower_bound != 0
        if is_fixed:
            continue

        parameters[tf1.GetParName(i_parameter)] = tf1.GetParameter(i_parameter)

    return parameters
440 
441 
def get_stats_from_tgraph(tgraph):
    """Get the summary statistics from the graph.

    Only the additional statistics provided by ``ValidationPlot.get_additional_stats``
    are collected; the result is an empty OrderedDict if there are none.
    """
    stats = collections.OrderedDict()

    additional_stats = ValidationPlot.get_additional_stats(tgraph)
    if not additional_stats:
        return stats

    for stat_name, stat_value in additional_stats.items():
        stats[stat_name] = stat_value
    return stats
452 
453 
def get_stats_from_th(th):
    """Get the summary statistics from the histogram.

    Collects the entry count, additional stats from ``ValidationPlot``, the x
    under-/overflow (one dimensional histograms only, if non-zero) and the moments
    derived from ``GetStats``: x average and standard deviation, plus y average,
    y standard deviation, covariance and correlation when the histogram provides them.
    """
    stats = collections.OrderedDict()
    stats["count"] = th.GetEntries()

    additional_stats = ValidationPlot.get_additional_stats(th)
    if additional_stats:
        for stat_name, stat_value in additional_stats.items():
            stats[stat_name] = stat_value

    if not isinstance(th, (ROOT.TH2, ROOT.TH3)):
        n_bins = th.GetXaxis().GetNbins()

        # Profiles report the number of entries of the flow bins rather than their content
        read_flow_bin = th.GetBinEntries if isinstance(th, ROOT.TProfile) else th.GetBinContent
        underflow_content = read_flow_bin(0)
        overflow_content = read_flow_bin(n_bins + 1)

        if underflow_content:
            stats["x underf."] = underflow_content

        if overflow_content:
            stats["x overf."] = overflow_content

    # GetStats only fills the sums the histogram type provides; the others stay NaN,
    # which is used below to detect the presence of y and xy moments.
    stats_values = np.full(7, np.nan)
    th.GetStats(stats_values)

    sum_w, sum_w2, sum_wx, sum_wx2, sum_wy, sum_wy2, sum_wxy = stats_values

    stats["x avg"] = np.divide(sum_wx, sum_w)
    stats["x std"] = np.divide(np.sqrt(sum_wx2 * sum_w - sum_wx * sum_wx), sum_w)

    if not np.isnan(sum_wy):
        # y moments are available (TH2 and TProfile)
        stats["y avg"] = np.divide(sum_wy, sum_w)
        stats["y std"] = np.divide(np.sqrt(sum_wy2 * sum_w - sum_wy * sum_wy), sum_w)

        if not np.isnan(sum_wxy):
            # mixed moment is available (TH2 only)
            stats["cov"] = np.divide((sum_wxy * sum_w - sum_wx * sum_wy), (sum_w * sum_w))
            stats["corr"] = np.divide(stats["cov"], (stats["x std"] * stats["y std"]))

    return stats
510 
511 
def compose_stats_label(title, additional_stats=None):
    """Render the summary statistics to a label string.

    Parameters
    ----------
    title : str
        First line of the label; omitted when empty.
    additional_stats : mapping, optional
        Ordered mapping from statistic name to value. Names are left aligned in a
        9 character column; string values are right aligned in 9 characters, all
        other values are formatted with three significant digits.

    Returns
    -------
    str
        The title line (if any) and one ``name: value`` line per statistic,
        joined by newlines.
    """
    # None instead of a shared mutable {} default
    if additional_stats is None:
        additional_stats = {}

    labeled_value_template = "{0:<9}: {1:.3g}"
    labeled_string_template = "{0:<9}: {1:>9s}"

    label_elements = [str(title)] if title else []
    for key, value in additional_stats.items():
        template = labeled_string_template if isinstance(value, str) else labeled_value_template
        label_elements.append(template.format(key, value))

    return "\n".join(label_elements)
529 
530 
def plot_tgraph_data_into(ax,
                          tgraph,
                          plot_errors=None,
                          label=None,
                          clip_to_data=True):
    """Plot a ROOT TGraph into a matplotlib axes

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    tgraph : ROOT.TGraph
        A plotable ROOT graph
    plot_errors : bool, optional
        Plot graph as errorbar plot. Default None means True for TGraphErrors and False else.
    label : str
        label to be given to the plot
    clip_to_data : bool, optional
        Autoscale the axes tightly around the data while plotting.
    """
    # Operate on the given axes rather than on pyplot's "current" axes, which may be
    # a different axes or cause an unrelated figure to be created.
    ax.autoscale(tight=clip_to_data)

    if plot_errors is None:
        plot_errors = isinstance(tgraph, ROOT.TGraphErrors)

    x_taxis = tgraph.GetXaxis()
    y_taxis = tgraph.GetYaxis()

    xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())
    ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())

    n_points = tgraph.GetN()

    xs = np.empty(n_points, dtype=float)
    ys = np.empty(n_points, dtype=float)

    x_lower_errors = np.empty(n_points, dtype=float)
    x_upper_errors = np.empty(n_points, dtype=float)

    y_lower_errors = np.empty(n_points, dtype=float)
    y_upper_errors = np.empty(n_points, dtype=float)

    # Output parameters filled by TGraph.GetPoint
    x = ROOT.Double()
    y = ROOT.Double()

    for i_point in range(n_points):
        tgraph.GetPoint(i_point, x, y)
        xs[i_point] = float(x)
        ys[i_point] = float(y)

        x_lower_errors[i_point] = tgraph.GetErrorXlow(i_point)
        x_upper_errors[i_point] = tgraph.GetErrorXhigh(i_point)

        y_lower_errors[i_point] = tgraph.GetErrorYlow(i_point)
        y_upper_errors[i_point] = tgraph.GetErrorYhigh(i_point)

    if plot_errors:
        linecolor = root_color_to_matplot_color(tgraph.GetLineColor())

        ax.errorbar(xs,
                    ys,
                    xerr=[x_lower_errors, x_upper_errors],
                    yerr=[y_lower_errors, y_upper_errors],
                    fmt="none",
                    ecolor=linecolor,
                    label=label)

    else:
        markercolor = root_color_to_matplot_color(tgraph.GetMarkerColor())

        ax.scatter(xs,
                   ys,
                   c=markercolor,
                   s=2,
                   marker="+",
                   label=label)

    # The former unused x-limit computation via np.nanmin(xs) was dropped: it had
    # no effect and raised on graphs without points.

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.autoscale(tight=None)
618 
619 
def plot_th1_data_into(ax,
                       th1,
                       plot_errors=None,
                       label=None):
    """Plot a ROOT histogram into a matplotlib axes

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    th1 : ROOT.TH1
        A plotable one dimensional ROOT histogram
    plot_errors : bool, optional
        Plot histogram as errorbar plot. Default None means True for TProfile and for
        histograms storing squared weights (GetSumw2N() != 0), False else.
    label : str, optional
        label to be given to the histogram
    """

    if plot_errors is None:
        if isinstance(th1, ROOT.TProfile):
            plot_errors = True
        else:
            plot_errors = th1.GetSumw2N() != 0

    # Bin content
    x_taxis = th1.GetXaxis()
    n_bins = x_taxis.GetNbins()

    bin_ids_with_underflow = list(range(n_bins + 1))
    bin_ids_without_underflow = list(range(1, n_bins + 1))

    # Get the n_bins + 1 bin edges starting from the upper edge of the underflow bin 0
    bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin) for i_bin in bin_ids_with_underflow])

    # Bin center and width not including the underflow
    bin_centers = np.array([x_taxis.GetBinCenter(i_bin) for i_bin in bin_ids_without_underflow])
    bin_widths = np.array([x_taxis.GetBinWidth(i_bin) for i_bin in bin_ids_without_underflow])
    bin_x_errors = bin_widths / 2.0

    # Now for the histogram content not including the underflow
    bin_contents = np.array([th1.GetBinContent(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_errors = np.array([th1.GetBinError(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_upper_errors = np.array([th1.GetBinErrorUp(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_lower_errors = np.array([th1.GetBinErrorLow(i_bin) for i_bin in bin_ids_without_underflow])

    # Bins with neither content nor error are left out of the errorbar plot
    empty_bins = (bin_contents == 0) & (bin_y_errors == 0)

    is_discrete_binning = bool(x_taxis.GetLabels())
    bin_labels = [x_taxis.GetBinLabel(i_bin) for i_bin in bin_ids_without_underflow]

    xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())

    y_taxis = th1.GetYaxis()
    ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())

    # May set these from th1 properties
    y_log_scale = False
    histtype = "step"

    linecolor = root_color_to_matplot_color(th1.GetLineColor())

    if plot_errors:
        ax.errorbar(bin_centers[~empty_bins],
                    bin_contents[~empty_bins],
                    yerr=[bin_y_lower_errors[~empty_bins],
                          bin_y_upper_errors[~empty_bins]],
                    xerr=bin_x_errors[~empty_bins],
                    fmt="none",
                    ecolor=linecolor,
                    label=label)

        y_lower_bound, y_upper_bound = common_bounds(
            ax.get_ylim(),
            (th1.GetMinimum(flt_min), th1.GetMaximum(flt_max))
        )

        y_total_width = y_upper_bound - y_lower_bound
        ax.set_ylim(y_lower_bound - 0.02 * y_total_width, y_upper_bound + 0.02 * y_total_width)

    else:
        if is_discrete_binning:
            # Bars are centred on the bin centres so that they line up with the tick
            # labels set below. align is given explicitly because the matplotlib default
            # changed from 'edge' to 'center' in version 2.0.
            ax.bar(bin_centers,
                   bin_contents,
                   width=0.8 * bin_widths,
                   align="center",
                   label=label,
                   edgecolor=linecolor,
                   color=(1, 1, 1, 0),  # fill transparent white
                   log=y_log_scale)
        else:
            ax.hist(bin_centers,
                    bins=bin_edges,
                    weights=bin_contents,
                    edgecolor=linecolor,
                    histtype=histtype,
                    label=label,
                    log=y_log_scale)

        # Fixing that the limits sometimes clip in the y direction.
        # Only possible with positive content; otherwise the range [0, max] would be
        # empty or inverted.
        max_bin_content = np.max(bin_contents) if bin_contents.size else 0
        if max_bin_content > 0:
            ax.set_ylim(0, 1.02 * max_bin_content)

    if is_discrete_binning:
        ax.set_xticks(bin_centers)
        ax.set_xticklabels(bin_labels, rotation=0)

    total_width = bin_edges[-1] - bin_edges[0]
    if total_width != 0:
        ax.set_xlim(bin_edges[0] - 0.01 * total_width, bin_edges[-1] + 0.01 * total_width)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
731 
732 
def plot_th2_data_into(ax,
                       th2,
                       plot_3d=False,
                       label=None):
    """Plot a two dimensional ROOT histogram into a matplotlib axes

    The bin contents (without under- and overflow) are drawn as a colour map with
    ``ax.hist2d`` and a colorbar is attached next to the axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    th2 : ROOT.TH2
        A plotable two dimensional ROOT histogram
    plot_3d : bool, optional
        Plot as a three dimensional plot (not implemented)
    label : str, optional
        label to be given to the histogram

    Raises
    ------
    NotImplementedError
        If ``plot_3d`` is requested.
    """
    # Fail fast before reading any data from the histogram
    if plot_3d:
        raise NotImplementedError("3D plotting of two dimensional histograms not implemented yet")

    # Bin content
    x_taxis = th2.GetXaxis()
    y_taxis = th2.GetYaxis()

    x_n_bins = x_taxis.GetNbins()
    y_n_bins = y_taxis.GetNbins()

    x_bin_ids_with_underflow = list(range(x_n_bins + 1))
    y_bin_ids_with_underflow = list(range(y_n_bins + 1))

    x_bin_ids_without_underflow = list(range(1, x_n_bins + 1))
    y_bin_ids_without_underflow = list(range(1, y_n_bins + 1))

    # Get the n_bins + 1 bin edges starting from the upper edge of the underflow bin 0
    x_bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin) for i_bin in x_bin_ids_with_underflow])
    y_bin_edges = np.array([y_taxis.GetBinUpEdge(i_bin) for i_bin in y_bin_ids_with_underflow])

    # Bin centers not including the underflow
    x_bin_centers = np.array([x_taxis.GetBinCenter(i_bin) for i_bin in x_bin_ids_without_underflow])
    y_bin_centers = np.array([y_taxis.GetBinCenter(i_bin) for i_bin in y_bin_ids_without_underflow])

    # Centers and bin ids share the same grid layout, so their flattened orders match
    x_centers, y_centers = np.meshgrid(x_bin_centers, y_bin_centers)
    x_ids, y_ids = np.meshgrid(x_bin_ids_without_underflow, y_bin_ids_without_underflow)

    bin_contents = np.array([th2.GetBinContent(int(x_i_bin), int(y_i_bin))
                             for x_i_bin, y_i_bin in zip(x_ids.flat, y_ids.flat)])

    x_is_discrete_binning = bool(x_taxis.GetLabels())
    x_bin_labels = [x_taxis.GetBinLabel(i_bin) for i_bin in x_bin_ids_without_underflow]

    y_is_discrete_binning = bool(y_taxis.GetLabels())
    y_bin_labels = [y_taxis.GetBinLabel(i_bin) for i_bin in y_bin_ids_without_underflow]

    xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())
    ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())

    # May set this from th2 properties
    log_scale = False

    hist2d_options = dict(weights=bin_contents,
                          bins=[x_bin_edges, y_bin_edges],
                          label=label)
    if log_scale:
        hist2d_options["norm"] = matplotlib.colors.LogNorm()

    _, _, _, colorbar_mappable = ax.hist2d(x_centers.flatten(),
                                           y_centers.flatten(),
                                           **hist2d_options)

    lowest_color = colorbar_mappable.get_cmap()(0)

    # Dummy artist to show in legend
    ax.fill(0, 0, "-", color=lowest_color, label=label)

    colorbar_ax, _ = matplotlib.colorbar.make_axes(ax,
                                                   pad=0.02,
                                                   shrink=0.5,
                                                   anchor=(0.0, 1.0),
                                                   panchor=(1.0, 1.0),
                                                   aspect=10,
                                                   )

    matplotlib.colorbar.Colorbar(colorbar_ax, colorbar_mappable)

    if x_is_discrete_binning:
        ax.set_xticks(x_bin_centers)
        ax.set_xticklabels(x_bin_labels, rotation=0)

    if y_is_discrete_binning:
        ax.set_yticks(y_bin_centers)
        ax.set_yticklabels(y_bin_labels, rotation=0)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
842 
843 
def plot_tf1_data_into(ax,
                       tf1,
                       label=None):
    """Plots the data of the tf1 into a matplotlib axes

    The function is evaluated at ``max(tf1.GetNpx(), 100)`` equidistant points over
    its own range [GetXmin(), GetXmax()]. If that range is degenerate, the current
    x limits of the axes are used instead.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    tf1 : ROOT.TF1
        Function to be plotted.
    label : str, optional
        Label for the legend entry.

    Returns
    -------
    list of matplotlib.lines.Line2D or None
        The plotted line handles, or None if the function evaluated to zero at all
        points (nothing is plotted in that case).
    """
    lower_bound = tf1.GetXmin()
    upper_bound = tf1.GetXmax()
    n_plot_points = max(tf1.GetNpx(), 100)

    if lower_bound == upper_bound:
        lower_bound, upper_bound = ax.get_xlim()

    xs = np.linspace(lower_bound, upper_bound, n_plot_points)
    ys = [tf1.Eval(x) for x in xs]

    linecolor = root_color_to_matplot_color(tf1.GetLineColor())

    if any(y != 0 for y in ys):
        line_handles = ax.plot(xs, ys, color=linecolor, label=label)
        return line_handles

    return None
875 
876 
def root_color_to_matplot_color(root_color_index):
    """Look up a ROOT color by index and return it as an RGB tuple.

    Parameters
    ----------
    root_color_index : int
        Index of a color as defined in ROOT

    Returns
    -------
    (float, float, float)
        Red, green and blue fractions as reported by the ROOT TColor.
    """
    root_color = ROOT.gROOT.GetColor(root_color_index)
    red = root_color.GetRed()
    green = root_color.GetGreen()
    blue = root_color.GetBlue()
    return (red, green, blue)
892 
893 
def reformat_root_latex_to_matplotlib_latex(text):
    """Translate text containing ROOT pseudo latex into latex matplotlib understands.

    This is a heuristic rather than a parser. The text is split at single spaces.
    Parts recognised by ``is_root_latex_directive`` are wrapped into math mode, with
    ROOT's '#' replaced by a backslash. Any remaining '#word' sequence is turned
    into an inline math command. Context spanning several parts may be lost.
    """
    def reformat_part(text_part):
        if is_root_latex_directive(text_part):
            # All directives are wrapped into math mode
            text_part = r'$' + text_part.replace('#', '\\') + r'$'
        return re.sub("#([a-zA-Z_{}]*)", r"$\\\1$", text_part)

    return " ".join(reformat_part(text_part) for text_part in text.split(" "))
922 
923 
def is_root_latex_directive(text_part):
    """Tell whether a text part looks like a ROOT latex directive.

    That is the case if it starts with '#' or contains '_{' or '{}'.
    """
    if text_part.startswith('#'):
        return True
    return any(marker in text_part for marker in ('_{', '{}'))
927 
928 
def common_bounds(matplot_bounds, root_bounds):
    """Combine the axis range found by matplotlib with the range reported by ROOT.

    ROOT bounds of exactly (0, 0) are treated as unset and the matplotlib bounds are
    returned unchanged. Otherwise the union of both ranges is returned, ignoring NaNs.
    """
    lower_bound, upper_bound = matplot_bounds
    root_lower_bound, root_upper_bound = root_bounds

    if root_lower_bound == 0 and root_upper_bound == 0:
        return lower_bound, upper_bound

    return (np.nanmin((lower_bound, root_lower_bound)),
            np.nanmax((upper_bound, root_upper_bound)))
tracking.validation.matplotting.compose_stats_label
def compose_stats_label(title, additional_stats={})
Definition: matplotting.py:512
tracking.validation.matplotting.plot_tgraph
def plot_tgraph(tgraph, label=None, legend=None, style=None)
Definition: matplotting.py:241
tracking.validation.matplotting.plot_th1
def plot_th1(th1, label=None, legend=None, style=None, **kwd)
Definition: matplotting.py:211
tracking.validation.matplotting.get_logger
def get_logger()
Definition: matplotting.py:23
tracking.validation.matplotting.is_plotable
def is_plotable(tobject)
Definition: matplotting.py:60
tracking.validation.matplotting.plot_tmultigraph
def plot_tmultigraph(tmultigraph, label=None, legend=None, style=None, **kwd)
Definition: matplotting.py:185
tracking.validation.matplotting.get_stats_from_th
def get_stats_from_th(th)
Definition: matplotting.py:454
tracking.validation.matplotting.plot_th2_data_into
def plot_th2_data_into(ax, th2, plot_3d=False, label=None)
Definition: matplotting.py:733
tracking.validation.matplotting.put_legend_outside
def put_legend_outside(ax, title=None, bottom=True, select_handles=[], exclude_handles=[], force_add_legend=False)
Definition: matplotting.py:352
tracking.validation.matplotting.common_bounds
def common_bounds(matplot_bounds, root_bounds)
Definition: matplotting.py:929
tracking.validation.matplotting.root_color_to_matplot_color
def root_color_to_matplot_color(root_color_index)
Definition: matplotting.py:877
tracking.validation.matplotting.defaults
Definition: matplotting.py:40
tracking.validation.matplotting.create_label
def create_label(th_or_tgraph, label=None)
Definition: matplotting.py:285
tracking.validation.matplotting.plot_thstack
def plot_thstack(thstack, label=None, legend=None, style=None, **kwd)
Definition: matplotting.py:147
tracking.validation.matplotting.get_stats_from_tgraph
def get_stats_from_tgraph(tgraph)
Definition: matplotting.py:442
tracking.validation.matplotting.create_figure
def create_figure(title=None)
Definition: matplotting.py:263
tracking.validation.matplotting.put_legend
def put_legend(fig, ax, legend=None, top_handles=None, top_title=None, bottom_title=None)
Definition: matplotting.py:313
tracking.validation.matplotting.plot_th1_data_into
def plot_th1_data_into(ax, th1, plot_errors=None, label=None)
Definition: matplotting.py:620
tracking.validation.matplotting.reformat_root_latex_to_matplotlib_latex
def reformat_root_latex_to_matplotlib_latex(text)
Definition: matplotting.py:894
tracking.validation.matplotting.is_root_latex_directive
def is_root_latex_directive(text_part)
Definition: matplotting.py:924
tracking.validation.matplotting.get_fit
def get_fit(th1_or_tgraph)
Definition: matplotting.py:405
tracking.validation.matplotting.plot
def plot(tobject, **kwd)
Definition: matplotting.py:65
tracking.validation.matplotting.use_style
def use_style(plot_function)
Definition: matplotting.py:103
tracking.validation.matplotting.create_title
def create_title(tplotable)
Definition: matplotting.py:280
tracking.validation.matplotting.get_fit_parameters
def get_fit_parameters(tf1)
Definition: matplotting.py:412
tracking.validation.matplotting.plot_tgraph_data_into
def plot_tgraph_data_into(ax, tgraph, plot_errors=None, label=None, clip_to_data=True)
Definition: matplotting.py:531
tracking.validation.matplotting.plot_tf1_data_into
def plot_tf1_data_into(ax, tf1, label=None)
Definition: matplotting.py:844
tracking.validation.matplotting.plot_th2
def plot_th2(th2, label=None, legend=None, style=None, **kwd)
Definition: matplotting.py:129