1 """Tools to plot ROOT object into matplotlib"""
9 np.seterr(invalid=
'ignore')
13 flt_max = sys.float_info.max
15 flt_min = sys.float_info.min
18 from .plot
import ValidationPlot
24 """Getter for the logger instance of this file."""
25 return logging.getLogger(__name__)
33 import matplotlib.pyplot
as plt
34 import matplotlib.transforms
as transforms
36 raise ImportError(
"matplotlib is not installed in your basf2 environment. "
37 "You may install it with 'pip install matplotlib'")
41 """Default values of the plotting options"""
61 """Indicates if a module can be plotted with matplotlib using this module."""
62 return isinstance(tobject, plotable_classes)
66 """Plot the given plotable TObject.
70 tobject : ROOT.TObject
72 legend : bool, optional
73 Create a by-side legend containing statistical information.
74 style : list(str), optional
75 List of matplotlib styles to be used for the plotting.
79 matplotlib.figure.Figure
80 The figure containing the plot
82 if isinstance(tobject, ROOT.TH2):
85 elif isinstance(tobject, ROOT.THStack):
88 elif isinstance(tobject, ROOT.TMultiGraph):
91 elif isinstance(tobject, ROOT.TGraph):
94 elif isinstance(tobject, (ROOT.TProfile, ROOT.TH1)):
98 raise ValueError(
"Plotting to matplot lib only supported for TH1, TProfile, and THStack.")
104 """Decorator to adjust the matplotlib style before plotting"""
105 @functools.wraps(plot_function)
106 def plot_function_with_style(*args, **kwds):
107 style = kwds.get(
"style",
None)
109 style = defaults.style
110 matplotlib.style.use(style)
112 matplotlib.rcParams[
"patch.linewidth"] = 2
113 matplotlib.rcParams[
'mathtext.fontset'] =
'custom'
114 matplotlib.rcParams[
'mathtext.rm'] =
'DejaVu Sans'
115 matplotlib.rcParams[
'mathtext.it'] =
'DejaVu Sans:italic'
116 matplotlib.rcParams[
'mathtext.cal'] =
'DejaVu Sans:italic'
117 matplotlib.rcParams[
'mathtext.bf'] =
'DejaVu Sans:bold'
120 matplotlib.rcParams[
'font.sans-serif'] = [
'DejaVu Sans']
121 matplotlib.rcParams[
'font.monospace'] = [
'cmtt10']
123 return plot_function(*args, **kwds)
125 return plot_function_with_style
134 """Plots a two dimensional histogram"""
152 """Plots a stack of histograms"""
158 ths = list(thstack.GetHists())
160 if all(isinstance(th, ROOT.TH3)
for th
in ths):
161 raise ValueError(
"Cannot plot a stack of three dimensional histograms")
163 elif all(isinstance(th, ROOT.TH2)
and not isinstance(th, ROOT.TH3)
for th
in ths):
164 raise NotImplementedError(
"Cannot plot a stack of two dimensional histograms")
166 elif all(isinstance(th, ROOT.TH1)
and not isinstance(th, (ROOT.TH3, ROOT.TH2))
for th
in ths):
174 max_bin_content = thstack.GetMaximum(
"nostack")
175 ax.set_ylim(0, 1.02 * max_bin_content)
178 ValueError(
"Stack of histograms with mismatching dimensions")
190 """Plots multiple overlayed graphs"""
195 for tgraph
in tmultigraph.GetListOfGraphs():
201 (tmultigraph.GetMinimum(), tmultigraph.GetMaximum())
204 ax.set_ylim(y_lower_bound, y_upper_bound)
216 """Plots a one dimensional histogram including the fit function if present"""
223 if th1.GetSumOfWeights() == 0:
224 get_logger().info(
"Skipping empty histogram %s", th1.GetName())
236 put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title=
"Fit")
245 """Plots graph including the fit function if present"""
259 put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title=
"Fit")
264 """Create a new figure already making space for a by side legend if requested
268 (matplotlib.figure.Figure, matplotlib.axes.Axes)
269 A prepared figure and axes into which can be plotted.
272 ax = fig.add_subplot(111)
275 ax.set_title(title, y=1.04)
281 """Extract the title from the plotable ROOT object and translate to ROOT latex"""
286 """Create a label from the plotable object incorporating available summary statistics."""
288 label = defaults.label
291 if isinstance(th_or_tgraph, ROOT.TH1):
296 elif isinstance(th_or_tgraph, ROOT.TGraph):
297 tgraph = th_or_tgraph
301 elif isinstance(th_or_tgraph, ROOT.TF1):
307 raise ValueError(
"Can only create a label from a ROOT.TH or ROOT.TGraph")
319 """Put the legend of the plot
321 Put one legend to right beside the axes space for some plot handles if desired.
322 Put one legend at the bottom for the remaining plot handles.
326 legend = defaults.legend
329 fig_width = fig.get_figwidth()
331 fig.set_figwidth(1.0 * 4.0 / 3.0 * fig_width)
334 box = ax.get_position()
335 ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
339 exclude_handles=top_handles,
340 force_add_legend=
True,
343 select_handles=top_handles,
348 exclude_handles=top_handles,
357 force_add_legend=False,
359 """Put a legned right beside the axes space"""
361 if not select_handles:
362 select_handles, _ = ax.get_legend_handles_labels()
365 select_handles = [handle
for handle
in select_handles
if handle
not in exclude_handles]
367 fig = ax.get_figure()
379 legend_artist = ax.legend(handles=select_handles,
380 bbox_to_anchor=(1.02, 0),
383 prop={
"family":
"monospace"},
394 legend_artist = ax.legend(handles=select_handles,
395 bbox_to_anchor=(1.02, 1),
398 prop={
"family":
"monospace"},
402 ax.add_artist(legend_artist)
406 """Retrieve a potential fit function form the plotable object"""
407 for tobject
in th1_or_tgraph.GetListOfFunctions():
408 if isinstance(tobject, ROOT.TF1):
413 """Retrieve the fitted parameters explicitly excluding fixed parameters.
415 Fixed parameters are usually additional stats entries that are already shown in the main
418 parameters = collections.OrderedDict()
420 formula = tf1.GetTitle()
421 parameters[
"formula"] = formula
423 n_parameters = tf1.GetNpar()
424 for i_parameter
in range(n_parameters):
426 lower_bound = ROOT.Double()
427 upper_bound = ROOT.Double()
428 tf1.GetParLimits(i_parameter, lower_bound, upper_bound)
430 name = tf1.GetParName(i_parameter)
431 value = tf1.GetParameter(i_parameter)
433 if lower_bound == upper_bound
and lower_bound != 0:
437 parameters[name] = value
443 """Get the summary statistics from the graph"""
444 stats = collections.OrderedDict()
446 additional_stats = ValidationPlot.get_additional_stats(tgraph)
448 for key, value
in list(additional_stats.items()):
455 """Get the summary statistics from the histogram"""
456 stats = collections.OrderedDict()
458 stats[
"count"] = th.GetEntries()
460 additional_stats = ValidationPlot.get_additional_stats(th)
462 for key, value
in list(additional_stats.items()):
465 if not isinstance(th, (ROOT.TH2, ROOT.TH3)):
466 x_taxis = th.GetXaxis()
467 n_bins = x_taxis.GetNbins()
469 if isinstance(th, ROOT.TProfile):
470 underflow_content = th.GetBinEntries(0)
471 overflow_content = th.GetBinEntries(n_bins + 1)
473 underflow_content = th.GetBinContent(0)
474 overflow_content = th.GetBinContent(n_bins + 1)
476 if underflow_content:
477 stats[
"x underf."] = underflow_content
480 stats[
"x overf."] = overflow_content
482 stats_values = np.array([np.nan] * 7)
483 th.GetStats(stats_values)
485 sum_w = stats_values[0]
486 sum_w2 = stats_values[1]
487 sum_wx = stats_values[2]
488 sum_wx2 = stats_values[3]
489 sum_wy = stats_values[4]
490 sum_wy2 = stats_values[5]
491 sum_wxy = stats_values[6]
495 stats[
"x avg"] = np.divide(sum_wx, sum_w)
496 stats[
"x std"] = np.divide(np.sqrt(sum_wx2 * sum_w - sum_wx * sum_wx), sum_w)
500 stats[
"x avg"] = np.divide(sum_wx, sum_w)
501 stats[
"x std"] = np.divide(np.sqrt(sum_wx2 * sum_w - sum_wx * sum_wx), sum_w)
502 stats[
"y avg"] = np.divide(sum_wy, sum_w)
503 stats[
"y std"] = np.divide(np.sqrt(sum_wy2 * sum_w - sum_wy * sum_wy), sum_w)
505 if not np.isnan(sum_wxy):
506 stats[
"cov"] = np.divide((sum_wxy * sum_w - sum_wx * sum_wy), (sum_w * sum_w))
507 stats[
"corr"] = np.divide(stats[
"cov"], (stats[
"x std"] * stats[
"y std"]))
513 """Render the summary statistics to a label string."""
514 keys = list(additional_stats.keys())
515 labeled_value_template =
"{0:<9}: {1:.3g}"
516 labeled_string_template =
"{0:<9}: {1:>9s}"
519 label_elements.append(str(title))
521 for key, value
in list(additional_stats.items()):
522 if isinstance(value, str):
523 label_element = labeled_string_template.format(key, value)
525 label_element = labeled_value_template.format(key, value)
526 label_elements.append(label_element)
528 return "\n".join(label_elements)
536 """Plot a ROOT TGraph into a matplotlib axes
540 ax : matplotlib.axes.Axes
541 An axes space in which to plot
543 A plotable one dimensional ROOT histogram
544 plot_errors : bool, optional
545 Plot graph as errorbar plot. Default None means True for TGraphErrors and False else.
547 label to be given to the plot
550 plt.autoscale(tight=clip_to_data)
552 if plot_errors
is None:
553 if isinstance(tgraph, ROOT.TGraphErrors):
558 x_taxis = tgraph.GetXaxis()
559 y_taxis = tgraph.GetYaxis()
564 n_points = tgraph.GetN()
566 xs = np.ndarray((n_points,), dtype=float)
567 ys = np.ndarray((n_points,), dtype=float)
569 x_lower_errors = np.ndarray((n_points,), dtype=float)
570 x_upper_errors = np.ndarray((n_points,), dtype=float)
572 y_lower_errors = np.ndarray((n_points,), dtype=float)
573 y_upper_errors = np.ndarray((n_points,), dtype=float)
578 for i_point
in range(n_points):
579 tgraph.GetPoint(i_point, x, y)
580 xs[i_point] = float(x)
581 ys[i_point] = float(y)
583 x_lower_errors[i_point] = tgraph.GetErrorXlow(i_point)
584 x_upper_errors[i_point] = tgraph.GetErrorXhigh(i_point)
586 y_lower_errors[i_point] = tgraph.GetErrorYlow(i_point)
587 y_upper_errors[i_point] = tgraph.GetErrorYhigh(i_point)
590 root_color_index = tgraph.GetLineColor()
595 xerr=[x_lower_errors, x_upper_errors],
596 yerr=[y_lower_errors, y_upper_errors],
602 root_color_index = tgraph.GetMarkerColor()
611 x_lower_bound, x_upper_bound = ax.get_xlim()
612 x_lower_bound = min(x_lower_bound, np.nanmin(xs))
614 ax.set_xlabel(xlabel)
615 ax.set_ylabel(ylabel)
617 plt.autoscale(tight=
None)
624 """Plot a ROOT histogram into a matplotlib axes
628 ax : matplotlib.axes.Axes
629 An axes space in which to plot
631 A plotable one dimensional ROOT histogram
632 plot_errors : bool, optional
633 Plot histogram as errorbar plot. Default None means True for TProfile and False else.
634 label : str, optional
635 label to be given to the histogram
638 if plot_errors
is None:
639 if isinstance(th1, ROOT.TProfile):
642 plot_errors = th1.GetSumw2N() != 0
645 x_taxis = th1.GetXaxis()
646 n_bins = x_taxis.GetNbins()
648 bin_ids_with_underflow = list(range(n_bins + 1))
649 bin_ids_without_underflow = list(range(1, n_bins + 1))
652 bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin)
for i_bin
in bin_ids_with_underflow])
655 bin_centers = np.array([x_taxis.GetBinCenter(i_bin)
for i_bin
in bin_ids_without_underflow])
656 bin_widths = np.array([x_taxis.GetBinWidth(i_bin)
for i_bin
in bin_ids_without_underflow])
657 bin_x_errors = bin_widths / 2.0
660 bin_contents = np.array([th1.GetBinContent(i_bin)
for i_bin
in bin_ids_without_underflow])
661 bin_y_errors = np.array([th1.GetBinError(i_bin)
for i_bin
in bin_ids_without_underflow])
662 bin_y_upper_errors = np.array([th1.GetBinErrorUp(i_bin)
for i_bin
in bin_ids_without_underflow])
663 bin_y_lower_errors = np.array([th1.GetBinErrorLow(i_bin)
for i_bin
in bin_ids_without_underflow])
665 empty_bins = (bin_contents == 0) & (bin_y_errors == 0)
667 is_discrete_binning = bool(x_taxis.GetLabels())
668 bin_labels = [x_taxis.GetBinLabel(i_bin)
for i_bin
in bin_ids_without_underflow]
672 y_taxis = th1.GetYaxis()
679 root_color_index = th1.GetLineColor()
683 ax.errorbar(bin_centers[~empty_bins],
684 bin_contents[~empty_bins],
685 yerr=[bin_y_lower_errors[~empty_bins],
686 bin_y_upper_errors[~empty_bins]],
687 xerr=bin_x_errors[~empty_bins],
694 (th1.GetMinimum(flt_min), th1.GetMaximum(flt_max))
697 y_total_width = y_upper_bound - y_lower_bound
698 ax.set_ylim(y_lower_bound - 0.02 * y_total_width, y_upper_bound + 0.02 * y_total_width)
701 if is_discrete_binning:
702 ax.bar(bin_centers - 0.4,
712 weights=bin_contents,
719 ax.set_ylim(0, 1.02 * max(bin_contents))
721 if is_discrete_binning:
722 ax.set_xticks(bin_centers)
723 ax.set_xticklabels(bin_labels, rotation=0)
725 total_width = bin_edges[-1] - bin_edges[0]
727 ax.set_xlim(bin_edges[0] - 0.01 * total_width, bin_edges[-1] + 0.01 * total_width)
729 ax.set_xlabel(xlabel)
730 ax.set_ylabel(ylabel)
737 """Plot a ROOT histogram into a matplotlib axes
741 ax : matplotlib.axes.Axes
742 An axes space in which to plot
744 A plotable two dimensional ROOT histogram
745 plot_3d : bool, optional
746 Plot as a three dimensional plot
747 label : str, optional
748 label to be given to the histogram
752 x_taxis = th2.GetXaxis()
753 y_taxis = th2.GetYaxis()
755 x_n_bins = x_taxis.GetNbins()
756 y_n_bins = y_taxis.GetNbins()
758 x_bin_ids_with_underflow = list(range(x_n_bins + 1))
759 y_bin_ids_with_underflow = list(range(y_n_bins + 1))
761 x_bin_ids_without_underflow = list(range(1, x_n_bins + 1))
762 y_bin_ids_without_underflow = list(range(1, y_n_bins + 1))
765 x_bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin)
for i_bin
in x_bin_ids_with_underflow])
766 y_bin_edges = np.array([y_taxis.GetBinUpEdge(i_bin)
for i_bin
in y_bin_ids_with_underflow])
769 x_bin_centers = np.array([x_taxis.GetBinCenter(i_bin)
for i_bin
in x_bin_ids_without_underflow])
770 y_bin_centers = np.array([y_taxis.GetBinCenter(i_bin)
for i_bin
in y_bin_ids_without_underflow])
772 x_centers, y_centers = np.meshgrid(x_bin_centers, y_bin_centers)
774 x_ids, y_ids = np.meshgrid(x_bin_ids_without_underflow, y_bin_ids_without_underflow)
776 bin_contents = np.array([th2.GetBinContent(int(x_i_bin), int(y_i_bin))
777 for x_i_bin, y_i_bin
in zip(x_ids.flat, y_ids.flat)])
779 x_is_discrete_binning = bool(x_taxis.GetLabels())
780 x_bin_labels = [x_taxis.GetBinLabel(i_bin)
for i_bin
in x_bin_ids_without_underflow]
782 y_is_discrete_binning = bool(y_taxis.GetLabels())
783 y_bin_labels = [y_taxis.GetBinLabel(i_bin)
for i_bin
in y_bin_ids_without_underflow]
791 root_color_index = th2.GetLineColor()
795 raise NotImplementedError(
"3D plotting of two dimensional histograms not implemented yet")
799 _, _, _, colorbar_mappable = ax.hist2d(x_centers.flatten(),
801 weights=bin_contents,
802 bins=[x_bin_edges, y_bin_edges],
804 norm=matplotlib.colors.LogNorm())
806 _, _, _, colorbar_mappable = ax.hist2d(x_centers.flatten(),
808 weights=bin_contents,
809 bins=[x_bin_edges, y_bin_edges],
812 lowest_color = colorbar_mappable.get_cmap()(0)
815 ax.fill(0, 0,
"-", color=lowest_color, label=label)
822 colorbar_ax, _ = matplotlib.colorbar.make_axes(ax,
830 matplotlib.colorbar.Colorbar(colorbar_ax, colorbar_mappable)
832 if x_is_discrete_binning:
833 ax.set_xticks(x_bin_centers)
834 ax.set_xticklabels(x_bin_labels, rotation=0)
836 if y_is_discrete_binning:
837 ax.set_yticks(y_bin_centers)
838 ax.set_yticklabels(y_bin_labels, rotation=0)
840 ax.set_xlabel(xlabel)
841 ax.set_ylabel(ylabel)
847 """Plots the data of the tf1 into a matplotlib axes
851 ax : matplotlib.axes.Axes
852 An axes space in which to plot
854 Function to be ploted.
855 label : str, optional
856 Label for the legend entry.
859 lower_bound = tf1.GetXmin()
860 upper_bound = tf1.GetXmin()
861 n_plot_points = max(tf1.GetNpx(), 100)
863 if lower_bound == upper_bound:
864 lower_bound, upper_bound = ax.get_xlim()
866 xs = np.linspace(lower_bound, upper_bound, n_plot_points)
867 ys = [tf1.Eval(x)
for x
in xs]
869 root_color_index = tf1.GetLineColor()
872 if any(y != 0
for y
in ys):
873 line_handles = ax.plot(xs, ys, color=linecolor, label=label)
878 """Translates ROOT color into an RGB tuple.
882 root_color_index : int
883 Index of a color as defined in ROOT
887 (float, float, float)
888 tuple of floats that represent to RGB color fractions.
890 tcolor = ROOT.gROOT.GetColor(root_color_index)
891 return (tcolor.GetRed(), tcolor.GetGreen(), tcolor.GetBlue())
895 """Takes text that may contain ROOT pseudo latex directives and
896 translate it in to proper latex that can be understood by matplotlib"""
905 text_parts = text.split(
" ")
907 reformatted_text_parts = []
909 for text_part
in text_parts:
912 reformatted_text_part =
r'$' + text_part.replace(
'#',
'\\') +
r'$'
915 reformatted_text_part = text_part
917 reformatted_text_part = re.sub(
"#([a-zA-Z_{}]*)",
r"$\\\1$", reformatted_text_part)
919 reformatted_text_parts.append(reformatted_text_part)
921 return " ".join(reformatted_text_parts)
925 """Test if a text part looks like a ROOT latex directive"""
926 return text_part.startswith(
'#')
or '_{' in text_part
or '{}' in text_part
930 """Assign the common lower and upper bounds for a plot"""
931 lower_bound, upper_bound = matplot_bounds
932 root_lower_bound, root_upper_bound = root_bounds
934 if root_lower_bound != 0
or root_upper_bound != 0:
935 lower_bound = np.nanmin((lower_bound, root_lower_bound))
936 upper_bound = np.nanmax((upper_bound, root_upper_bound))
938 return lower_bound, upper_bound