Belle II Software  release-08-01-10
matplotting.py
1 
8 
9 """Tools to plot ROOT object into matplotlib"""
10 
11 import ROOT
12 
13 import re
14 import functools
15 import numpy as np
16 import collections
17 from tracking.validation.plot import ValidationPlot
18 import ctypes
19 
20 import logging
21 
22 import sys
23 
# Silence numpy's warnings about invalid floating point operations (e.g. 0 / 0).
# The summary statistics in this module divide by sums of weights that may be zero.
np.seterr(invalid='ignore')

# Largest representable finite float; handed to TH1.GetMaximum in plot_th1_data_into
flt_max = sys.float_info.max

# Smallest positive normalized float; handed to TH1.GetMinimum in plot_th1_data_into
flt_min = sys.float_info.min
29 
30 
def get_logger():
    """Return the logger that is named after this module."""
    module_logger = logging.getLogger(__name__)
    return module_logger
34 
35 
# matplotlib is an optional dependency: import it guardedly so that a missing
# installation is reported with a hint on how to install it.
try:
    import matplotlib
    # Switch to noninteractive backend
    # (selected before pyplot is imported below)
    matplotlib.use('Agg')

    import matplotlib.pyplot as plt
except ImportError:
    # Re-raise with a message that tells the user how to get matplotlib
    raise ImportError("matplotlib is not installed in your basf2 environment. "
                      "You may install it with 'pip install matplotlib'")
45 
46 
class defaults:
    """Module-wide fallbacks for plotting options that the caller leaves as None."""

    # Whether labels carrying the summary statistics are generated (read by create_label)
    label = True

    # Whether a legend is put beside the axes (read by put_legend)
    legend = True

    # Name of the matplotlib style applied by the use_style decorator
    style = "bmh"
55 
56 
57 
# ROOT classes that is_plotable() accepts. The check is done with isinstance,
# hence instances of derived classes are accepted as well.
plotable_classes = (
    ROOT.TH1,
    ROOT.TH2,
    ROOT.THStack,
    ROOT.TGraph,
    ROOT.TGraphErrors,
    ROOT.TMultiGraph
)
66 
67 
def is_plotable(tobject):
    """Tell whether the given object is an instance of one of the plotable ROOT classes."""
    can_be_plotted = isinstance(tobject, plotable_classes)
    return can_be_plotted
71 
72 
def plot(tobject, **kwd):
    """Plot the given plotable TObject.

    Dispatches to the plot function that matches the ROOT type of the object.
    Two dimensional histograms are tested before one dimensional ones,
    because the TH1 check would match them as well.

    Parameters
    ----------
    tobject : ROOT.TObject
        Plotable TObject.
    legend : bool, optional
        Create a by-side legend containing statistical information.
    style : list(str), optional
        List of matplotlib styles to be used for the plotting.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot

    Raises
    ------
    ValueError
        If the object is of none of the supported ROOT types.
    """
    if isinstance(tobject, ROOT.TH2):
        return plot_th2(tobject, **kwd)

    if isinstance(tobject, ROOT.THStack):
        return plot_thstack(tobject, **kwd)

    if isinstance(tobject, ROOT.TMultiGraph):
        return plot_tmultigraph(tobject, **kwd)

    if isinstance(tobject, ROOT.TGraph):
        return plot_tgraph(tobject, **kwd)

    if isinstance(tobject, (ROOT.TProfile, ROOT.TH1)):
        return plot_th1(tobject, **kwd)

    # The message names every type that is dispatched above
    raise ValueError("Plotting to matplotlib is only supported for "
                     "TH1, TProfile, TH2, THStack, TGraph and TMultiGraph.")
107 
108 
def use_style(plot_function):
    """Decorator that applies the matplotlib style and font settings before plotting.

    The style is taken from the ``style`` keyword argument of the decorated call
    and falls back to ``defaults.style`` if it is absent or None.
    The settings are written to the global matplotlib state and are not restored afterwards.
    """
    @functools.wraps(plot_function)
    def plot_function_with_style(*args, **kwds):
        requested_style = kwds.get("style")
        matplotlib.style.use(defaults.style if requested_style is None else requested_style)

        # Settings applied on top of the selected style, in this order
        rc_overrides = (
            ("patch.linewidth", 2),
            ('mathtext.fontset', 'custom'),
            ('mathtext.rm', 'DejaVu Sans'),
            ('mathtext.it', 'DejaVu Sans:italic'),
            ('mathtext.cal', 'DejaVu Sans:italic'),
            ('mathtext.bf', 'DejaVu Sans:bold'),
            ('font.sans-serif', ['DejaVu Sans']),
            ('font.monospace', ['cmtt10']),
        )
        for rc_key, rc_value in rc_overrides:
            matplotlib.rcParams[rc_key] = rc_value

        return plot_function(*args, **kwds)

    return plot_function_with_style
132 
133 
@use_style
def plot_th2(th2,
             label=None,
             legend=None,
             style=None,
             **kwd):
    """Create a figure showing a two dimensional histogram.

    Parameters
    ----------
    th2 : ROOT.TH2
        Histogram to be plotted.
    label : bool, optional
        Forwarded to create_label; None falls back to the module default.
    legend : bool, optional
        Forwarded to put_legend; None falls back to the module default.
    style : optional
        Not used in the body; picked up by the use_style decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot
    """
    fig, ax = create_figure(title=create_title(th2))

    plot_th2_data_into(ax, th2, label=create_label(th2, label))

    put_legend(fig, ax, legend=legend)
    return fig
150 
151 
@use_style
def plot_thstack(thstack,
                 label=None,
                 legend=None,
                 style=None,
                 **kwd):
    """Plots a stack of histograms

    Only stacks of one dimensional histograms are supported.
    The histograms are drawn overlaid into the same axes, not stacked.

    Parameters
    ----------
    thstack : ROOT.THStack
        Stack of histograms to be plotted.
    label : bool, optional
        Forwarded to create_label; None falls back to the module default.
    legend : bool, optional
        Forwarded to put_legend; None falls back to the module default.
    style : optional
        Not used in the body; picked up by the use_style decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot

    Raises
    ------
    ValueError
        If the stack holds three dimensional histograms or histograms of mismatching dimensions.
    NotImplementedError
        If the stack holds two dimensional histograms.
    """

    title = create_title(thstack)
    fig, ax = create_figure(title=title)

    ths = list(thstack.GetHists())

    if all(isinstance(th, ROOT.TH3) for th in ths):
        raise ValueError("Cannot plot a stack of three dimensional histograms")

    elif all(isinstance(th, ROOT.TH2) and not isinstance(th, ROOT.TH3) for th in ths):
        raise NotImplementedError("Cannot plot a stack of two dimensional histograms")

    elif all(isinstance(th, ROOT.TH1) and not isinstance(th, (ROOT.TH3, ROOT.TH2)) for th in ths):
        # currently plot only non stacked
        for th1 in ths:
            th1_label = create_label(th1, label=label)
            plot_th1_data_into(ax, th1, label=th1_label)

        # Fixing that the limits sometimes clip in the y direction
        max_bin_content = thstack.GetMaximum("nostack")
        ax.set_ylim(0, 1.02 * max_bin_content)

    else:
        # Previously the exception was only constructed and never raised,
        # so a mixed stack silently resulted in an empty figure.
        raise ValueError("Stack of histograms with mismatching dimensions")

    put_legend(fig, ax, legend=legend)
    return fig
188 
189 
@use_style
def plot_tmultigraph(tmultigraph,
                     label=None,
                     legend=None,
                     style=None,
                     **kwd):
    """Create a figure with all graphs of a multi graph drawn into the same axes.

    Parameters
    ----------
    tmultigraph : ROOT.TMultiGraph
        Collection of graphs to be plotted.
    label : bool, optional
        Forwarded to create_label; None falls back to the module default.
    legend : bool, optional
        Forwarded to put_legend; None falls back to the module default.
    style : optional
        Not used in the body; picked up by the use_style decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot
    """
    fig, ax = create_figure(title=create_title(tmultigraph))

    for member_graph in tmultigraph.GetListOfGraphs():
        plot_tgraph_data_into(ax,
                              member_graph,
                              label=create_label(member_graph, label=label))

    # Widen the y range such that it also covers the bounds stored in the multi graph
    root_y_bounds = (tmultigraph.GetMinimum(), tmultigraph.GetMaximum())
    y_bounds = common_bounds(ax.get_ylim(), root_y_bounds)
    ax.set_ylim(y_bounds[0], y_bounds[1])

    put_legend(fig, ax, legend=legend)
    return fig
214 
215 
@use_style
def plot_th1(th1,
             label=None,
             legend=None,
             style=None,
             **kwd):
    """Create a figure of a one dimensional histogram and its fit function, if one is attached.

    A histogram with a vanishing sum of weights is skipped:
    the figure is returned with only its title set.

    Parameters
    ----------
    th1 : ROOT.TH1
        Histogram to be plotted.
    label : bool, optional
        Forwarded to create_label; None falls back to the module default.
    legend : bool, optional
        Forwarded to put_legend; None falls back to the module default.
    style : optional
        Not used in the body; picked up by the use_style decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot
    """
    fig, ax = create_figure(title=create_title(th1))
    histogram_label = create_label(th1, label)

    if th1.GetSumOfWeights() == 0:
        get_logger().info("Skipping empty histogram %s", th1.GetName())
        return fig

    plot_th1_data_into(ax, th1, label=histogram_label)

    # The handles of the fit go to a separate legend titled "Fit"
    fit_handles = None
    tf1 = get_fit(th1)
    if tf1:
        fit_handles = plot_tf1_data_into(ax, tf1, label=create_label(tf1, label=label))

    put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title="Fit")
    return fig
244 
245 
@use_style
def plot_tgraph(tgraph,
                label=None,
                legend=None,
                style=None,
                **kwd):
    """Plots graph including the fit function if present

    Parameters
    ----------
    tgraph : ROOT.TGraph
        Graph to be plotted.
    label : bool, optional
        Forwarded to create_label; None falls back to the module default.
    legend : bool, optional
        Forwarded to put_legend; None falls back to the module default.
    style : optional
        Not used in the body; picked up by the use_style decorator.
    **kwd
        Further keyword arguments are accepted and ignored, like in the other plot functions,
        such that plot() can forward the same options to every plotable type.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot
    """
    title = create_title(tgraph)
    fig, ax = create_figure(title=title)

    tgraph_label = create_label(tgraph, label=label)
    plot_tgraph_data_into(ax, tgraph, label=tgraph_label)

    tf1 = get_fit(tgraph)
    if tf1:
        tf1_label = create_label(tf1, label=label)
        fit_handles = plot_tf1_data_into(ax, tf1, label=tf1_label)
    else:
        fit_handles = None

    # The handles of the fit go to a separate legend titled "Fit"
    put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title="Fit")
    return fig
267 
268 
def create_figure(title=None):
    """Create a new figure with a single axes and optionally set the title of the axes.

    Parameters
    ----------
    title : str, optional
        Title to be put above the axes. Nothing is set if it is empty or None.

    Returns
    -------
    (matplotlib.figure.Figure, matplotlib.axes.Axes)
        A prepared figure and axes into which can be plotted.
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    if not title:
        return fig, ax

    # Lift the title slightly above its default position
    ax.set_title(title, y=1.04)
    return fig, ax
284 
285 
def create_title(tplotable):
    """Fetch the title of the plotable ROOT object and translate it to matplotlib latex."""
    root_title = tplotable.GetTitle()
    return reformat_root_latex_to_matplotlib_latex(root_title)
289 
290 
def create_label(th_or_tgraph, label=None):
    """Compose a legend label holding the summary statistics of a plotable object.

    Parameters
    ----------
    th_or_tgraph : ROOT.TH1 or ROOT.TGraph or ROOT.TF1
        Object to be described. Fit functions are described by their fit parameters.
    label : bool, optional
        Switch to request the label. None falls back to ``defaults.label``.

    Returns
    -------
    str or None
        The rendered label or None, if no label is requested.

    Raises
    ------
    ValueError
        If a label is requested for an object of an unsupported type.
    """
    label_requested = defaults.label if label is None else label
    if not label_requested:
        return None

    if isinstance(th_or_tgraph, ROOT.TH1):
        entries = get_stats_from_th(th_or_tgraph)

    elif isinstance(th_or_tgraph, ROOT.TGraph):
        entries = get_stats_from_tgraph(th_or_tgraph)

    elif isinstance(th_or_tgraph, ROOT.TF1):
        entries = get_fit_parameters(th_or_tgraph)

    else:
        raise ValueError("Can only create a label from a ROOT.TH or ROOT.TGraph")

    return compose_stats_label("", entries)
317 
318 
def put_legend(fig,
               ax,
               legend=None,
               top_handles=None,
               top_title=None,
               bottom_title=None):
    """Make room beside the axes and place the legend(s) there.

    If ``top_handles`` are given, they get their own legend at the top right of the axes,
    while all remaining handles are put into a legend at the bottom right.
    Without them only the bottom legend is created.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to be widened.
    ax : matplotlib.axes.Axes
        Axes to be shrunk and to receive the legend(s).
    legend : bool, optional
        Switch for the legend. None falls back to ``defaults.legend``.
    top_handles : list, optional
        Plot handles to be shown in the separate top legend.
    top_title : str, optional
        Title of the top legend.
    bottom_title : str, optional
        Title of the bottom legend.
    """
    legend_requested = defaults.legend if legend is None else legend
    if not legend_requested:
        return

    # Widen the figure to 4/3 of its width ...
    fig.set_figwidth(1.0 * 4.0 / 3.0 * fig.get_figwidth())

    # ... and let the axes keep only 75 % of their width, which frees the right side
    bounds = ax.get_position()
    ax.set_position([bounds.x0, bounds.y0, bounds.width * 0.75, bounds.height])

    if not top_handles:
        put_legend_outside(ax,
                           exclude_handles=top_handles,
                           title=bottom_title)
        return

    # The bottom legend is added as an extra artist before the top legend is created
    put_legend_outside(ax,
                       exclude_handles=top_handles,
                       force_add_legend=True,
                       title=bottom_title)
    put_legend_outside(ax,
                       select_handles=top_handles,
                       bottom=False,
                       title=top_title)
356 
357 
def put_legend_outside(ax,
                       title=None,
                       bottom=True,
                       select_handles=None,
                       exclude_handles=None,
                       force_add_legend=False,
                       ):
    """Put a legend right beside the axes space

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes next to which the legend is placed.
    title : str, optional
        Title of the legend.
    bottom : bool, optional
        Anchor the legend at the lower right corner of the axes (default),
        otherwise at the upper right corner.
    select_handles : list, optional
        Plot handles to be shown. If empty or None all labeled handles of the axes are used.
    exclude_handles : list, optional
        Plot handles to be left out of the legend.
    force_add_legend : bool, optional
        Additionally add the legend as an artist to the axes,
        as done by put_legend for the first of two legends.
    """
    # None is used as the default instead of a shared mutable list
    if not select_handles:
        select_handles, _ = ax.get_legend_handles_labels()

    if exclude_handles:
        select_handles = [handle for handle in select_handles if handle not in exclude_handles]

    # Anchor slightly right of the axes: lower left corner of the legend (loc=3) to the bottom,
    # or upper left corner of the legend (loc=2) to the top.
    if bottom:
        anchor, loc = (1.02, 0), 3
    else:
        anchor, loc = (1.02, 1), 2

    legend_artist = ax.legend(handles=select_handles,
                              bbox_to_anchor=anchor,
                              borderaxespad=0.,
                              loc=loc,
                              prop={"family": "monospace"},
                              title=title)

    if force_add_legend:
        ax.add_artist(legend_artist)
409 
410 
def get_fit(th1_or_tgraph):
    """Return the first TF1 in the list of functions of the plotable object, or None."""
    attached_functions = th1_or_tgraph.GetListOfFunctions()
    fit_candidates = (candidate for candidate in attached_functions
                      if isinstance(candidate, ROOT.TF1))
    return next(fit_candidates, None)
416 
417 
def get_fit_parameters(tf1):
    """Collect the formula and the non-fixed parameters of a fit function.

    A parameter is regarded as fixed, if its lower and upper limits are equal and nonzero.
    Fixed parameters are usually additional stats entries that are already shown in the main
    legend for the plot, hence they are left out here.

    Parameters
    ----------
    tf1 : ROOT.TF1
        Fit function to be described.

    Returns
    -------
    collections.OrderedDict
        The title of the function under the key "formula" followed by the parameter values by name.
    """
    parameters = collections.OrderedDict()
    parameters["formula"] = tf1.GetTitle()

    for index in range(tf1.GetNpar()):
        # The limits are handed back through the two output arguments
        low = ctypes.c_double()
        high = ctypes.c_double()
        tf1.GetParLimits(index, low, high)

        name = tf1.GetParName(index)
        value = tf1.GetParameter(index)

        is_fixed = low.value == high.value and low.value != 0
        if not is_fixed:
            parameters[name] = value

    return parameters
446 
447 
def get_stats_from_tgraph(tgraph):
    """Collect the summary statistics of a graph.

    Only the additional stats provided by ``ValidationPlot.get_additional_stats`` are used.

    Returns
    -------
    collections.OrderedDict
        Statistics values by their names.
    """
    stats = collections.OrderedDict()

    extra_stats = ValidationPlot.get_additional_stats(tgraph)
    if extra_stats:
        stats.update(extra_stats.items())

    return stats
458 
459 
def get_stats_from_th(th):
    """Collect the summary statistics of a histogram.

    The result contains, in this order: the number of entries, the additional stats provided
    by ``ValidationPlot.get_additional_stats``, the under- and overflow of one dimensional
    histograms (if nonzero), and the averages and standard deviations computed from the
    sums that ``th.GetStats`` fills. Covariance and correlation are added, if the mixed
    sum is filled as well.

    Returns
    -------
    collections.OrderedDict
        Statistics values by their names.
    """
    stats = collections.OrderedDict()
    stats["count"] = th.GetEntries()

    extra_stats = ValidationPlot.get_additional_stats(th)
    if extra_stats:
        stats.update(extra_stats.items())

    if not isinstance(th, (ROOT.TH2, ROOT.TH3)):
        n_bins = th.GetXaxis().GetNbins()

        # Profiles report the entries of the flow bins, other histograms their content
        read_bin = th.GetBinEntries if isinstance(th, ROOT.TProfile) else th.GetBinContent
        underflow_content = read_bin(0)
        overflow_content = read_bin(n_bins + 1)

        if underflow_content:
            stats["x underf."] = underflow_content

        if overflow_content:
            stats["x overf."] = overflow_content

    # Slots that GetStats does not fill remain nan, which is used below
    # to tell which of the sums are available.
    sums = np.full(7, np.nan)
    th.GetStats(sums)

    # sum_wy and sum_wy2 are only filled for TH2 and TProfile, sum_wxy only for TH2
    sum_w, _sum_w2, sum_wx, sum_wx2, sum_wy, sum_wy2, sum_wxy = sums

    def std_from_sums(sum_wv, sum_wv2):
        """Standard deviation from the weighted sum and the weighted sum of squares."""
        return np.divide(np.sqrt(sum_wv2 * sum_w - sum_wv * sum_wv), sum_w)

    stats["x avg"] = np.divide(sum_wx, sum_w)
    stats["x std"] = std_from_sums(sum_wx, sum_wx2)

    if not np.isnan(sum_wy):
        stats["y avg"] = np.divide(sum_wy, sum_w)
        stats["y std"] = std_from_sums(sum_wy, sum_wy2)

        if not np.isnan(sum_wxy):
            stats["cov"] = np.divide((sum_wxy * sum_w - sum_wx * sum_wy), (sum_w * sum_w))
            stats["corr"] = np.divide(stats["cov"], (stats["x std"] * stats["y std"]))

    return stats
516 
517 
def compose_stats_label(title, additional_stats=None):
    """Render a title and named statistics into a multi line label string.

    Each entry becomes one line with the name left aligned in nine characters.
    String values are right aligned in nine characters,
    all other values are formatted with three significant digits.

    Parameters
    ----------
    title : str
        First line of the label. Left out if empty.
    additional_stats : dict, optional
        Values by their names in the order they shall appear.

    Returns
    -------
    str
        The lines of the label joined by newlines.
    """
    entries = {} if additional_stats is None else additional_stats

    lines = [str(title)] if title else []
    for name, value in entries.items():
        template = "{0:<9}: {1:>9s}" if isinstance(value, str) else "{0:<9}: {1:.3g}"
        lines.append(template.format(name, value))

    return "\n".join(lines)
536 
537 
def plot_tgraph_data_into(ax,
                          tgraph,
                          plot_errors=None,
                          label=None,
                          clip_to_data=True):
    """Plot a ROOT TGraph into a matplotlib axes

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    tgraph : ROOT.TGraph
        A plotable ROOT graph
    plot_errors : bool, optional
        Plot graph as errorbar plot. Default None means True for TGraphErrors and False else.
    label : str
        label to be given to the plot
    clip_to_data : bool, optional
        Passed as ``tight`` to the autoscaling of the axes while the graph is plotted.
    """

    # Autoscale the axes that were passed in, not whichever axes pyplot regards as current
    ax.autoscale(tight=clip_to_data)

    if plot_errors is None:
        plot_errors = isinstance(tgraph, ROOT.TGraphErrors)

    xlabel = reformat_root_latex_to_matplotlib_latex(tgraph.GetXaxis().GetTitle())
    ylabel = reformat_root_latex_to_matplotlib_latex(tgraph.GetYaxis().GetTitle())

    n_points = tgraph.GetN()
    point_ids = range(n_points)

    xs = np.empty((n_points,), dtype=float)
    ys = np.empty((n_points,), dtype=float)

    # GetPoint hands the coordinates back through its two output arguments
    x = ctypes.c_double()
    y = ctypes.c_double()
    for i_point in point_ids:
        tgraph.GetPoint(i_point, x, y)
        xs[i_point] = float(x.value)
        ys[i_point] = float(y.value)

    if plot_errors:
        # The errors are only needed for the errorbar plot
        x_lower_errors = np.array([tgraph.GetErrorXlow(i_point) for i_point in point_ids], dtype=float)
        x_upper_errors = np.array([tgraph.GetErrorXhigh(i_point) for i_point in point_ids], dtype=float)
        y_lower_errors = np.array([tgraph.GetErrorYlow(i_point) for i_point in point_ids], dtype=float)
        y_upper_errors = np.array([tgraph.GetErrorYhigh(i_point) for i_point in point_ids], dtype=float)

        root_color_index = tgraph.GetLineColor()
        linecolor = root_color_to_matplot_color(root_color_index)

        ax.errorbar(xs,
                    ys,
                    xerr=[x_lower_errors, x_upper_errors],
                    yerr=[y_lower_errors, y_upper_errors],
                    fmt="none",
                    ecolor=linecolor,
                    label=label)

    else:
        root_color_index = tgraph.GetMarkerColor()
        markercolor = root_color_to_matplot_color(root_color_index)

        # The former computation of an unused lower x bound was dropped here;
        # its np.nanmin(xs) raised for graphs without points.
        ax.scatter(xs,
                   ys,
                   c=markercolor,
                   s=2,
                   marker="+",
                   label=label)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Hand the tight setting back to matplotlib's default
    ax.autoscale(tight=None)
625 
626 
def plot_th1_data_into(ax,
                       th1,
                       plot_errors=None,
                       label=None):
    """Plot a ROOT histogram into a matplotlib axes

    Histograms are drawn as errorbar plot, as bar plot (for axes with bin labels)
    or as step histogram. The x limits are set to the range of the bins plus a margin of 1 %
    and the axes labels are taken from the axes titles of the histogram.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    th1 : ROOT.TH1
        A plotable one dimensional ROOT histogram
    plot_errors : bool, optional
        Plot histogram as errorbar plot. Default None means True for TProfile
        and for histograms with ``GetSumw2N() != 0`` and False else.
    label : str, optional
        label to be given to the histogram
    """

    if plot_errors is None:
        plot_errors = isinstance(th1, ROOT.TProfile) or th1.GetSumw2N() != 0

    # Bin content
    x_taxis = th1.GetXaxis()
    n_bins = x_taxis.GetNbins()

    bin_ids_with_underflow = range(n_bins + 1)
    bin_ids_without_underflow = range(1, n_bins + 1)

    # Get the n_bins + 1 bin edges starting from the underflow bin 0
    bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin) for i_bin in bin_ids_with_underflow])

    # Bin center and width not including the underflow
    bin_centers = np.array([x_taxis.GetBinCenter(i_bin) for i_bin in bin_ids_without_underflow])
    bin_widths = np.array([x_taxis.GetBinWidth(i_bin) for i_bin in bin_ids_without_underflow])
    bin_x_errors = bin_widths / 2.0

    # Now for the histogram content not including the underflow
    bin_contents = np.array([th1.GetBinContent(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_errors = np.array([th1.GetBinError(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_upper_errors = np.array([th1.GetBinErrorUp(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_lower_errors = np.array([th1.GetBinErrorLow(i_bin) for i_bin in bin_ids_without_underflow])

    # Bins without content and without error are left out of the errorbar plot
    empty_bins = (bin_contents == 0) & (bin_y_errors == 0)

    # An axis that carries bin labels is treated as discrete
    is_discrete_binning = bool(x_taxis.GetLabels())
    bin_labels = [x_taxis.GetBinLabel(i_bin) for i_bin in bin_ids_without_underflow]

    xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())

    y_taxis = th1.GetYaxis()
    ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())

    # May set these from th1 properties
    y_log_scale = False
    histtype = "step"

    root_color_index = th1.GetLineColor()
    linecolor = root_color_to_matplot_color(root_color_index)

    if plot_errors:
        ax.errorbar(bin_centers[~empty_bins],
                    bin_contents[~empty_bins],
                    yerr=[bin_y_lower_errors[~empty_bins],
                          bin_y_upper_errors[~empty_bins]],
                    xerr=bin_x_errors[~empty_bins],
                    fmt="none",
                    ecolor=linecolor,
                    label=label)

        # Make sure the y range also covers the extrema reported by the histogram
        y_lower_bound, y_upper_bound = common_bounds(
            ax.get_ylim(),
            (th1.GetMinimum(flt_min), th1.GetMaximum(flt_max))
        )

        y_total_width = y_upper_bound - y_lower_bound
        ax.set_ylim(y_lower_bound - 0.02 * y_total_width, y_upper_bound + 0.02 * y_total_width)

    else:
        if is_discrete_binning:
            # Center the bars on the bin centers, where the tick labels are put below.
            # The former offset of -0.4 relied on edge aligned bars, which is no longer
            # the default of matplotlib, hence the alignment is stated explicitly.
            ax.bar(bin_centers,
                   bin_contents,
                   width=0.8,
                   align="center",
                   label=label,
                   edgecolor=linecolor,
                   color=(1, 1, 1, 0),  # fill transparent white
                   log=y_log_scale)
        else:
            # Replay the binned content as a weighted histogram of the bin centers
            ax.hist(bin_centers,
                    bins=bin_edges,
                    weights=bin_contents,
                    edgecolor=linecolor,
                    histtype=histtype,
                    label=label,
                    log=y_log_scale)

        # Fixing that the limits sometimes clip in the y direction
        ax.set_ylim(0, 1.02 * max(bin_contents))

    if is_discrete_binning:
        ax.set_xticks(bin_centers)
        ax.set_xticklabels(bin_labels, rotation=0)

    total_width = bin_edges[-1] - bin_edges[0]
    if total_width != 0:
        ax.set_xlim(bin_edges[0] - 0.01 * total_width, bin_edges[-1] + 0.01 * total_width)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
738 
739 
def plot_th2_data_into(ax,
                       th2,
                       plot_3d=False,
                       label=None):
    """Draw a two dimensional ROOT histogram as a color map into a matplotlib axes.

    A colorbar is attached to the right of the axes and a dummy artist
    is added such that the histogram shows up in the legend.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    th2 : ROOT.TH2
        A plotable two dimensional ROOT histogram
    plot_3d : bool, optional
        Plot as a three dimensional plot. Not implemented; raises NotImplementedError.
    label : str, optional
        label to be given to the histogram
    """
    x_taxis = th2.GetXaxis()
    y_taxis = th2.GetYaxis()

    x_n_bins = x_taxis.GetNbins()
    y_n_bins = y_taxis.GetNbins()

    # Ids of the regular bins, i.e. without the underflow bin 0
    x_bin_ids = list(range(1, x_n_bins + 1))
    y_bin_ids = list(range(1, y_n_bins + 1))

    # The n_bins + 1 bin edges are the upper edges starting from the underflow bin 0
    x_bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin) for i_bin in range(x_n_bins + 1)])
    y_bin_edges = np.array([y_taxis.GetBinUpEdge(i_bin) for i_bin in range(y_n_bins + 1)])

    x_bin_centers = np.array([x_taxis.GetBinCenter(i_bin) for i_bin in x_bin_ids])
    y_bin_centers = np.array([y_taxis.GetBinCenter(i_bin) for i_bin in y_bin_ids])

    x_centers, y_centers = np.meshgrid(x_bin_centers, y_bin_centers)

    # Same order as the flattened grid of centers: the x bin runs fastest
    bin_contents = np.array([th2.GetBinContent(x_i_bin, y_i_bin)
                             for y_i_bin in y_bin_ids
                             for x_i_bin in x_bin_ids])

    # An axis that carries bin labels is treated as discrete
    x_is_discrete_binning = bool(x_taxis.GetLabels())
    x_bin_labels = [x_taxis.GetBinLabel(i_bin) for i_bin in x_bin_ids]

    y_is_discrete_binning = bool(y_taxis.GetLabels())
    y_bin_labels = [y_taxis.GetBinLabel(i_bin) for i_bin in y_bin_ids]

    xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())
    ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())

    # May set these from th2 properties
    log_scale = False

    if plot_3d:
        raise NotImplementedError("3D plotting of two dimensional histograms not implemented yet")

    # Replay the binned content as a weighted histogram of the bin centers
    hist2d_options = dict(weights=bin_contents,
                          bins=[x_bin_edges, y_bin_edges],
                          label=label)
    if log_scale:
        hist2d_options["norm"] = matplotlib.colors.LogNorm()

    _, _, _, colorbar_mappable = ax.hist2d(x_centers.flatten(),
                                           y_centers.flatten(),
                                           **hist2d_options)

    lowest_color = colorbar_mappable.get_cmap()(0)

    # Dummy artist to show in legend
    ax.fill(0, 0, "-", color=lowest_color, label=label)

    colorbar_ax, _ = matplotlib.colorbar.make_axes(ax,
                                                   pad=0.02,
                                                   shrink=0.5,
                                                   anchor=(0.0, 1.0),
                                                   panchor=(1.0, 1.0),
                                                   aspect=10,
                                                   )

    matplotlib.colorbar.Colorbar(colorbar_ax, colorbar_mappable)

    if x_is_discrete_binning:
        ax.set_xticks(x_bin_centers)
        ax.set_xticklabels(x_bin_labels, rotation=0)

    if y_is_discrete_binning:
        ax.set_yticks(y_bin_centers)
        ax.set_yticklabels(y_bin_labels, rotation=0)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
849 
850 
def plot_tf1_data_into(ax,
                       tf1,
                       label=None):
    """Plots the data of the tf1 into a matplotlib axes

    The function is sampled on its own range from ``GetXmin()`` to ``GetXmax()``.
    Only if that range is empty the current x limits of the axes are used instead.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    tf1 : ROOT.TF1
        Function to be ploted.
    label : str, optional
        Label for the legend entry.

    Returns
    -------
    list or None
        The line handles returned by ``ax.plot`` or None,
        if the function vanishes at all sampled points and nothing was drawn.
    """

    lower_bound = tf1.GetXmin()
    # Formerly GetXmin() was read a second time here, such that the range
    # always appeared empty and the limits of the axes were used.
    upper_bound = tf1.GetXmax()
    n_plot_points = max(tf1.GetNpx(), 100)

    if lower_bound == upper_bound:
        lower_bound, upper_bound = ax.get_xlim()

    xs = np.linspace(lower_bound, upper_bound, n_plot_points)
    ys = [tf1.Eval(x) for x in xs]

    root_color_index = tf1.GetLineColor()
    linecolor = root_color_to_matplot_color(root_color_index)

    # Functions that are zero everywhere are not drawn
    if not any(y != 0 for y in ys):
        return None

    line_handles = ax.plot(xs, ys, color=linecolor, label=label)
    return line_handles
882 
883 
def root_color_to_matplot_color(root_color_index):
    """Look up a ROOT color by index and return its red, green and blue components.

    Parameters
    ----------
    root_color_index : int
        Index of a color as defined in ROOT

    Returns
    -------
    (float, float, float)
        The values of GetRed(), GetGreen() and GetBlue() of the color.
    """
    tcolor = ROOT.gROOT.GetColor(root_color_index)
    red = tcolor.GetRed()
    green = tcolor.GetGreen()
    blue = tcolor.GetBlue()
    return (red, green, blue)
899 
900 
def reformat_root_latex_to_matplotlib_latex(text):
    """Translate text with ROOT pseudo latex directives to latex that matplotlib understands.

    This is a simple word by word translation and no parser for the latex dialect of ROOT.
    The text is split at single spaces and every part is treated on its own,
    so context spanning several parts may get lost.
    """

    def translate(text_part):
        """Translate a single space separated part of the text."""
        if is_root_latex_directive(text_part):
            # A directive is wrapped into math mode as a whole
            text_part = r'$' + text_part.replace('#', '\\') + r'$'

        # Remaining directives in the middle of a part get their own math mode
        return re.sub("#([a-zA-Z_{}]*)", r"$\\\1$", text_part)

    return " ".join(translate(text_part) for text_part in text.split(" "))
929 
930 
def is_root_latex_directive(text_part):
    """Tell whether a text part starts with '#' or contains a subscript '_{' or an empty group '{}'."""
    if text_part.startswith('#'):
        return True
    return any(marker in text_part for marker in ('_{', '{}'))
934 
935 
def common_bounds(matplot_bounds, root_bounds):
    """Combine the bounds from matplotlib and from ROOT to the widest common range.

    ROOT bounds that are both zero are ignored and nan values are skipped.

    Parameters
    ----------
    matplot_bounds : (float, float)
        Lower and upper bound as set in matplotlib.
    root_bounds : (float, float)
        Lower and upper bound as reported by ROOT.

    Returns
    -------
    (float, float)
        The common lower and upper bound.
    """
    lower_bound, upper_bound = matplot_bounds
    root_lower_bound, root_upper_bound = root_bounds

    if root_lower_bound == 0 and root_upper_bound == 0:
        return lower_bound, upper_bound

    return (np.nanmin((lower_bound, root_lower_bound)),
            np.nanmax((upper_bound, root_upper_bound)))
Definition: plot.py:1