9"""Tools to plot ROOT object into matplotlib"""
24np.seterr(invalid=
'ignore')
26flt_max = sys.float_info.max
28flt_min = sys.float_info.min
32 """Getter for the logger instance of this file."""
33 return logging.getLogger(__name__)
41 import matplotlib.pyplot
as plt
43 raise ImportError(
"matplotlib is not installed in your basf2 environment. "
44 "You may install it with 'pip install matplotlib'")
48 """Default values of the plotting options"""
def is_plotable(tobject):
    """Tell whether the given object can be drawn with matplotlib by this module.

    The object qualifies when it is an instance of one of the classes
    collected in ``plotable_classes``.
    """
    can_be_plotted = isinstance(tobject, plotable_classes)
    return can_be_plotted
def plot(tobject, **kwd):
    """Plot the given plotable TObject.

    Dispatches to the plot function that matches the ROOT type of the object
    and forwards all keyword arguments to it.

    Parameters
    ----------
    tobject : ROOT.TObject
        Plotable object. TH2, THStack, TMultiGraph, TGraph, TProfile and TH1
        are handled.
    legend : bool, optional
        Create a by-side legend containing statistical information.
    style : list(str), optional
        List of matplotlib styles to be used for the plotting.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot

    Raises
    ------
    ValueError
        If the object is of none of the supported types.
    """
    # The two dimensional histogram is tested before the generic TH1 check
    # so that it is not routed to the one dimensional plot function.
    if isinstance(tobject, ROOT.TH2):
        return plot_th2(tobject, **kwd)

    elif isinstance(tobject, ROOT.THStack):
        return plot_thstack(tobject, **kwd)

    elif isinstance(tobject, ROOT.TMultiGraph):
        return plot_tmultigraph(tobject, **kwd)

    elif isinstance(tobject, ROOT.TGraph):
        return plot_tgraph(tobject, **kwd)

    elif isinstance(tobject, (ROOT.TProfile, ROOT.TH1)):
        return plot_th1(tobject, **kwd)

    else:
        # The message names every type that is actually dispatched above.
        raise ValueError(
            "Plotting to matplotlib is only supported for "
            "TH1, TH2, TProfile, THStack, TGraph and TMultiGraph.")
def use_style(plot_function):
    """Decorator that configures the matplotlib style before the wrapped plot function runs.

    The style is taken from the ``style`` keyword argument of the call and
    falls back to ``defaults.style`` when it is absent or None. Afterwards a
    fixed set of rcParams (line width and fonts) is applied on top of it.
    All arguments are passed on to the wrapped function unchanged.
    """
    # Applied in this order on top of the selected style.
    rc_overrides = (
        ("patch.linewidth", 2),
        ('mathtext.fontset', 'custom'),
        ('mathtext.rm', 'DejaVu Sans'),
        ('mathtext.it', 'DejaVu Sans:italic'),
        ('mathtext.cal', 'DejaVu Sans:italic'),
        ('mathtext.bf', 'DejaVu Sans:bold'),
        ('font.sans-serif', ['DejaVu Sans']),
        ('font.monospace', ['cmtt10']),
    )

    @functools.wraps(plot_function)
    def plot_function_with_style(*args, **kwds):
        requested_style = kwds.get("style", None)
        effective_style = defaults.style if requested_style is None else requested_style
        matplotlib.style.use(effective_style)

        for rc_key, rc_value in rc_overrides:
            matplotlib.rcParams[rc_key] = rc_value

        return plot_function(*args, **kwds)

    return plot_function_with_style
140 """Plots a two dimensional histogram"""
142 title = reformat_root_latex_to_matplotlib_latex(th2.GetTitle())
143 fig, ax = create_figure(title=title)
145 th2_label = create_label(th2, label)
146 plot_th2_data_into(ax, th2, label=th2_label)
148 put_legend(fig, ax, legend=legend)
# NOTE(review): the parameter lines after ``thstack`` were not visible in the
# reviewed excerpt; ``label``, ``legend``, ``style`` and ``**kwd`` are
# reconstructed from their use in the body and in ``use_style`` - confirm.
def plot_thstack(thstack,
                 label=None,
                 legend=None,
                 style=None,
                 **kwd):
    """Plots a stack of histograms

    Only stacks made up entirely of one dimensional histograms are drawn.
    Each histogram is plotted into the same axes without stacking.

    Parameters
    ----------
    thstack : ROOT.THStack
        Stack of histograms to be plotted.
    label : optional
        Forwarded to ``create_label`` for every histogram of the stack.
    legend : optional
        Forwarded to ``put_legend``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot

    Raises
    ------
    ValueError
        If the stack consists of three dimensional histograms or of
        histograms with mismatching dimensions.
    NotImplementedError
        If the stack consists of two dimensional histograms.
    """
    title = create_title(thstack)
    fig, ax = create_figure(title=title)

    ths = list(thstack.GetHists())

    if all(isinstance(th, ROOT.TH3) for th in ths):
        raise ValueError("Cannot plot a stack of three dimensional histograms")

    elif all(isinstance(th, ROOT.TH2) and not isinstance(th, ROOT.TH3) for th in ths):
        raise NotImplementedError("Cannot plot a stack of two dimensional histograms")

    elif all(isinstance(th, ROOT.TH1) and not isinstance(th, (ROOT.TH3, ROOT.TH2)) for th in ths):
        for th1 in ths:
            th1_label = create_label(th1, label=label)
            plot_th1_data_into(ax, th1, label=th1_label)

        # The histograms are overlaid, so the y range follows the largest
        # single histogram rather than the stacked sum.
        max_bin_content = thstack.GetMaximum("nostack")
        ax.set_ylim(0, 1.02 * max_bin_content)

    else:
        # The exception was previously constructed but never raised, so
        # mixed stacks silently produced an empty figure.
        raise ValueError("Stack of histograms with mismatching dimensions")

    put_legend(fig, ax, legend=legend)
    return fig
# NOTE(review): the parameter lines after ``tmultigraph`` were not visible in
# the reviewed excerpt; ``label``, ``legend``, ``style`` and ``**kwd`` are
# reconstructed from their use in the body and in ``use_style`` - confirm.
def plot_tmultigraph(tmultigraph,
                     label=None,
                     legend=None,
                     style=None,
                     **kwd):
    """Plots multiple overlaid graphs

    Every graph of the multi graph is drawn into one common axes. The y range
    is then widened to also cover the minimum and maximum reported by the
    multi graph itself.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot
    """
    fig, ax = create_figure(title=create_title(tmultigraph))

    for member_graph in tmultigraph.GetListOfGraphs():
        member_label = create_label(member_graph, label=label)
        plot_tgraph_data_into(ax, member_graph, label=member_label)

    matplot_y_bounds = ax.get_ylim()
    root_y_bounds = (tmultigraph.GetMinimum(), tmultigraph.GetMaximum())
    y_low, y_high = common_bounds(matplot_y_bounds, root_y_bounds)
    ax.set_ylim(y_low, y_high)

    put_legend(fig, ax, legend=legend)
    return fig
222 """Plots a one dimensional histogram including the fit function if present"""
224 title = create_title(th1)
225 fig, ax = create_figure(title=title)
227 th1_label = create_label(th1, label)
229 if th1.GetSumOfWeights() == 0:
230 get_logger().info(
"Skipping empty histogram %s", th1.GetName())
233 plot_th1_data_into(ax, th1, label=th1_label)
237 tf1_label = create_label(tf1, label=label)
238 fit_handles = plot_tf1_data_into(ax, tf1, label=tf1_label)
242 put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title=
"Fit")
# NOTE(review): the parameter lines after ``tgraph`` and the branch taken when
# no fit is attached were not visible in the reviewed excerpt; they are
# reconstructed from the surrounding code - confirm.
def plot_tgraph(tgraph,
                label=None,
                legend=None,
                style=None):
    """Plots graph including the fit function if present

    The fit function, if one is attached to the graph, is drawn on top of the
    data and its handles are listed in a separate legend titled "Fit".

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot
    """
    fig, ax = create_figure(title=create_title(tgraph))

    data_label = create_label(tgraph, label=label)
    plot_tgraph_data_into(ax, tgraph, label=data_label)

    fit_handles = None
    fit_function = get_fit(tgraph)
    if fit_function:
        fit_label = create_label(fit_function, label=label)
        fit_handles = plot_tf1_data_into(ax, fit_function, label=fit_label)

    put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title="Fit")
    return fig
269def create_figure(title=None):
270 """Create a new figure already making space for a by side legend if requested
274 (matplotlib.figure.Figure, matplotlib.axes.Axes)
275 A prepared figure and axes into which can be plotted.
278 ax = fig.add_subplot(111)
281 ax.set_title(title, y=1.04)
def create_title(tplotable):
    """Read the title of the plotable ROOT object and convert it to matplotlib latex."""
    root_title = tplotable.GetTitle()
    return reformat_root_latex_to_matplotlib_latex(root_title)
291def create_label(th_or_tgraph, label=None):
292 """Create a label from the plotable object incorporating available summary statistics."""
294 label = defaults.label
297 if isinstance(th_or_tgraph, ROOT.TH1):
299 stats = get_stats_from_th(th)
300 label = compose_stats_label(
"", stats)
302 elif isinstance(th_or_tgraph, ROOT.TGraph):
303 tgraph = th_or_tgraph
304 stats = get_stats_from_tgraph(tgraph)
305 label = compose_stats_label(
"", stats)
307 elif isinstance(th_or_tgraph, ROOT.TF1):
309 parameters = get_fit_parameters(tf1)
310 label = compose_stats_label(
"", parameters)
313 raise ValueError(
"Can only create a label from a ROOT.TH or ROOT.TGraph")
325 """Put the legend of the plot
327 Put one legend to right beside the axes space for some plot handles
if desired.
328 Put one legend at the bottom
for the remaining plot handles.
332 legend = defaults.legend
335 fig_width = fig.get_figwidth()
337 fig.set_figwidth(1.0 * 4.0 / 3.0 * fig_width)
340 box = ax.get_position()
341 ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
344 put_legend_outside(ax,
345 exclude_handles=top_handles,
346 force_add_legend=
True,
348 put_legend_outside(ax,
349 select_handles=top_handles,
353 put_legend_outside(ax,
354 exclude_handles=top_handles,
358def put_legend_outside(ax,
363 force_add_legend=False,
365 """Put a legned right beside the axes space"""
367 if not select_handles:
368 select_handles, _ = ax.get_legend_handles_labels()
371 select_handles = [handle
for handle
in select_handles
if handle
not in exclude_handles]
385 legend_artist = ax.legend(handles=select_handles,
386 bbox_to_anchor=(1.02, 0),
389 prop={
"family":
"monospace"},
400 legend_artist = ax.legend(handles=select_handles,
401 bbox_to_anchor=(1.02, 1),
404 prop={
"family":
"monospace"},
408 ax.add_artist(legend_artist)
def get_fit(th1_or_tgraph):
    """Look up a fit function attached to the plotable object.

    Returns the first ROOT.TF1 found in the object's list of functions,
    or None when the list contains no TF1.
    """
    attached_tf1s = (candidate
                     for candidate in th1_or_tgraph.GetListOfFunctions()
                     if isinstance(candidate, ROOT.TF1))
    return next(attached_tf1s, None)
def get_fit_parameters(tf1):
    """Collect the fitted parameters of a function, leaving out fixed parameters.

    A parameter is treated as fixed when its lower and upper limit are equal
    and different from zero. The returned mapping starts with the entry
    "formula" holding the title of the function, followed by one entry per
    non-fixed parameter in parameter order.

    Returns
    -------
    collections.OrderedDict
        Parameter names mapped to their values.
    """
    fit_parameters = collections.OrderedDict()
    fit_parameters["formula"] = tf1.GetTitle()

    for index in range(tf1.GetNpar()):
        # The limits are handed back through the two passed doubles.
        limit_low = ctypes.c_double()
        limit_high = ctypes.c_double()
        tf1.GetParLimits(index, limit_low, limit_high)

        parameter_name = tf1.GetParName(index)
        parameter_value = tf1.GetParameter(index)

        is_fixed = limit_low.value == limit_high.value and limit_low.value != 0
        if not is_fixed:
            fit_parameters[parameter_name] = parameter_value

    return fit_parameters
def get_stats_from_tgraph(tgraph):
    """Collect the summary statistics of a graph.

    The statistics are exactly the additional stats that
    ``ValidationPlot.get_additional_stats`` reports for the graph,
    copied into an ordered mapping.
    """
    extra_stats = ValidationPlot.get_additional_stats(tgraph)
    return collections.OrderedDict(list(extra_stats.items()))
460def get_stats_from_th(th):
461 """Get the summary statistics from the histogram"""
462 stats = collections.OrderedDict()
464 stats[
"count"] = th.GetEntries()
466 additional_stats = ValidationPlot.get_additional_stats(th)
468 for key, value
in list(additional_stats.items()):
471 if not isinstance(th, (ROOT.TH2, ROOT.TH3)):
472 x_taxis = th.GetXaxis()
473 n_bins = x_taxis.GetNbins()
475 if isinstance(th, ROOT.TProfile):
476 underflow_content = th.GetBinEntries(0)
477 overflow_content = th.GetBinEntries(n_bins + 1)
479 underflow_content = th.GetBinContent(0)
480 overflow_content = th.GetBinContent(n_bins + 1)
482 if underflow_content:
483 stats[
"x underf."] = underflow_content
486 stats[
"x overf."] = overflow_content
488 stats_values = np.array([np.nan] * 7)
489 th.GetStats(stats_values)
491 sum_w = stats_values[0]
493 sum_wx = stats_values[2]
494 sum_wx2 = stats_values[3]
495 sum_wy = stats_values[4]
496 sum_wy2 = stats_values[5]
497 sum_wxy = stats_values[6]
501 stats[
"x avg"] = np.divide(sum_wx, sum_w)
502 stats[
"x std"] = np.divide(np.sqrt(sum_wx2 * sum_w - sum_wx * sum_wx), sum_w)
506 stats[
"x avg"] = np.divide(sum_wx, sum_w)
507 stats[
"x std"] = np.divide(np.sqrt(sum_wx2 * sum_w - sum_wx * sum_wx), sum_w)
508 stats[
"y avg"] = np.divide(sum_wy, sum_w)
509 stats[
"y std"] = np.divide(np.sqrt(sum_wy2 * sum_w - sum_wy * sum_wy), sum_w)
511 if not np.isnan(sum_wxy):
512 stats[
"cov"] = np.divide((sum_wxy * sum_w - sum_wx * sum_wy), (sum_w * sum_w))
513 stats[
"corr"] = np.divide(stats[
"cov"], (stats[
"x std"] * stats[
"y std"]))
def compose_stats_label(title, additional_stats=None):
    """Format a title and summary statistics as a multi line label.

    Each statistic becomes one line with the key left aligned in nine
    characters. String values are right aligned in nine characters, all other
    values are rendered with three significant digits.
    """
    stats = {} if additional_stats is None else additional_stats

    def render_line(key, value):
        if isinstance(value, str):
            return "{0:<9}: {1:>9s}".format(key, value)
        return "{0:<9}: {1:.3g}".format(key, value)

    lines = [str(title)] if title else []
    lines.extend(render_line(key, value) for key, value in list(stats.items()))
    return "\n".join(lines)
538def plot_tgraph_data_into(ax,
543 """Plot a ROOT TGraph into a matplotlib axes
547 ax : matplotlib.axes.Axes
548 An axes space in which to plot
550 A plotable one dimensional ROOT histogram
551 plot_errors : bool, optional
552 Plot graph
as errorbar plot. Default
None means
True for TGraphErrors
and False else.
554 label to be given to the plot
557 plt.autoscale(tight=clip_to_data)
559 if plot_errors
is None:
560 if isinstance(tgraph, ROOT.TGraphErrors):
565 x_taxis = tgraph.GetXaxis()
566 y_taxis = tgraph.GetYaxis()
568 xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())
569 ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())
571 n_points = tgraph.GetN()
573 xs = np.ndarray((n_points,), dtype=float)
574 ys = np.ndarray((n_points,), dtype=float)
576 x_lower_errors = np.ndarray((n_points,), dtype=float)
577 x_upper_errors = np.ndarray((n_points,), dtype=float)
579 y_lower_errors = np.ndarray((n_points,), dtype=float)
580 y_upper_errors = np.ndarray((n_points,), dtype=float)
582 x = ctypes.c_double()
583 y = ctypes.c_double()
585 for i_point
in range(n_points):
586 tgraph.GetPoint(i_point, x, y)
587 xs[i_point] = float(x.value)
588 ys[i_point] = float(y.value)
590 x_lower_errors[i_point] = tgraph.GetErrorXlow(i_point)
591 x_upper_errors[i_point] = tgraph.GetErrorXhigh(i_point)
593 y_lower_errors[i_point] = tgraph.GetErrorYlow(i_point)
594 y_upper_errors[i_point] = tgraph.GetErrorYhigh(i_point)
597 root_color_index = tgraph.GetLineColor()
598 linecolor = root_color_to_matplot_color(root_color_index)
602 xerr=[x_lower_errors, x_upper_errors],
603 yerr=[y_lower_errors, y_upper_errors],
609 root_color_index = tgraph.GetMarkerColor()
610 markercolor = root_color_to_matplot_color(root_color_index)
618 x_lower_bound, x_upper_bound = ax.get_xlim()
619 x_lower_bound = min(x_lower_bound, np.nanmin(xs))
621 ax.set_xlabel(xlabel)
622 ax.set_ylabel(ylabel)
624 plt.autoscale(tight=
None)
627def plot_th1_data_into(ax,
631 """Plot a ROOT histogram into a matplotlib axes
635 ax : matplotlib.axes.Axes
636 An axes space in which to plot
638 A plotable one dimensional ROOT histogram
639 plot_errors : bool, optional
640 Plot histogram
as errorbar plot. Default
None means
True for TProfile
and False else.
641 label : str, optional
642 label to be given to the histogram
645 if plot_errors
is None:
646 if isinstance(th1, ROOT.TProfile):
649 plot_errors = th1.GetSumw2N() != 0
652 x_taxis = th1.GetXaxis()
653 n_bins = x_taxis.GetNbins()
655 bin_ids_with_underflow = list(range(n_bins + 1))
656 bin_ids_without_underflow = list(range(1, n_bins + 1))
659 bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin)
for i_bin
in bin_ids_with_underflow])
662 bin_centers = np.array([x_taxis.GetBinCenter(i_bin)
for i_bin
in bin_ids_without_underflow])
663 bin_widths = np.array([x_taxis.GetBinWidth(i_bin)
for i_bin
in bin_ids_without_underflow])
664 bin_x_errors = bin_widths / 2.0
667 bin_contents = np.array([th1.GetBinContent(i_bin)
for i_bin
in bin_ids_without_underflow])
668 bin_y_errors = np.array([th1.GetBinError(i_bin)
for i_bin
in bin_ids_without_underflow])
669 bin_y_upper_errors = np.array([th1.GetBinErrorUp(i_bin)
for i_bin
in bin_ids_without_underflow])
670 bin_y_lower_errors = np.array([th1.GetBinErrorLow(i_bin)
for i_bin
in bin_ids_without_underflow])
672 empty_bins = (bin_contents == 0) & (bin_y_errors == 0)
674 is_discrete_binning = bool(x_taxis.GetLabels())
675 bin_labels = [x_taxis.GetBinLabel(i_bin)
for i_bin
in bin_ids_without_underflow]
677 xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())
679 y_taxis = th1.GetYaxis()
680 ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())
686 root_color_index = th1.GetLineColor()
687 linecolor = root_color_to_matplot_color(root_color_index)
690 ax.errorbar(bin_centers[~empty_bins],
691 bin_contents[~empty_bins],
692 yerr=[bin_y_lower_errors[~empty_bins],
693 bin_y_upper_errors[~empty_bins]],
694 xerr=bin_x_errors[~empty_bins],
699 y_lower_bound, y_upper_bound = common_bounds(
701 (th1.GetMinimum(flt_min), th1.GetMaximum(flt_max))
704 y_total_width = y_upper_bound - y_lower_bound
705 ax.set_ylim(y_lower_bound - 0.02 * y_total_width, y_upper_bound + 0.02 * y_total_width)
708 if is_discrete_binning:
709 ax.bar(bin_centers - 0.4,
719 weights=bin_contents,
726 ax.set_ylim(0, 1.02 * max(bin_contents))
728 if is_discrete_binning:
729 ax.set_xticks(bin_centers)
730 ax.set_xticklabels(bin_labels, rotation=0)
732 total_width = bin_edges[-1] - bin_edges[0]
734 ax.set_xlim(bin_edges[0] - 0.01 * total_width, bin_edges[-1] + 0.01 * total_width)
736 ax.set_xlabel(xlabel)
737 ax.set_ylabel(ylabel)
740def plot_th2_data_into(ax,
744 """Plot a ROOT histogram into a matplotlib axes
748 ax : matplotlib.axes.Axes
749 An axes space in which to plot
751 A plotable two dimensional ROOT histogram
752 plot_3d : bool, optional
753 Plot
as a three dimensional plot
754 label : str, optional
755 label to be given to the histogram
759 x_taxis = th2.GetXaxis()
760 y_taxis = th2.GetYaxis()
762 x_n_bins = x_taxis.GetNbins()
763 y_n_bins = y_taxis.GetNbins()
765 x_bin_ids_with_underflow = list(range(x_n_bins + 1))
766 y_bin_ids_with_underflow = list(range(y_n_bins + 1))
768 x_bin_ids_without_underflow = list(range(1, x_n_bins + 1))
769 y_bin_ids_without_underflow = list(range(1, y_n_bins + 1))
772 x_bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin)
for i_bin
in x_bin_ids_with_underflow])
773 y_bin_edges = np.array([y_taxis.GetBinUpEdge(i_bin)
for i_bin
in y_bin_ids_with_underflow])
776 x_bin_centers = np.array([x_taxis.GetBinCenter(i_bin)
for i_bin
in x_bin_ids_without_underflow])
777 y_bin_centers = np.array([y_taxis.GetBinCenter(i_bin)
for i_bin
in y_bin_ids_without_underflow])
779 x_centers, y_centers = np.meshgrid(x_bin_centers, y_bin_centers)
781 x_ids, y_ids = np.meshgrid(x_bin_ids_without_underflow, y_bin_ids_without_underflow)
783 bin_contents = np.array([th2.GetBinContent(int(x_i_bin), int(y_i_bin))
784 for x_i_bin, y_i_bin
in zip(x_ids.flat, y_ids.flat)])
786 x_is_discrete_binning = bool(x_taxis.GetLabels())
787 x_bin_labels = [x_taxis.GetBinLabel(i_bin)
for i_bin
in x_bin_ids_without_underflow]
789 y_is_discrete_binning = bool(y_taxis.GetLabels())
790 y_bin_labels = [y_taxis.GetBinLabel(i_bin)
for i_bin
in y_bin_ids_without_underflow]
792 xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())
793 ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())
802 raise NotImplementedError(
"3D plotting of two dimensional histograms not implemented yet")
806 _, _, _, colorbar_mappable = ax.hist2d(x_centers.flatten(),
808 weights=bin_contents,
809 bins=[x_bin_edges, y_bin_edges],
811 norm=matplotlib.colors.LogNorm())
813 _, _, _, colorbar_mappable = ax.hist2d(x_centers.flatten(),
815 weights=bin_contents,
816 bins=[x_bin_edges, y_bin_edges],
819 lowest_color = colorbar_mappable.get_cmap()(0)
822 ax.fill(0, 0,
"-", color=lowest_color, label=label)
829 colorbar_ax, _ = matplotlib.colorbar.make_axes(ax,
837 matplotlib.colorbar.Colorbar(colorbar_ax, colorbar_mappable)
839 if x_is_discrete_binning:
840 ax.set_xticks(x_bin_centers)
841 ax.set_xticklabels(x_bin_labels, rotation=0)
843 if y_is_discrete_binning:
844 ax.set_yticks(y_bin_centers)
845 ax.set_yticklabels(y_bin_labels, rotation=0)
847 ax.set_xlabel(xlabel)
848 ax.set_ylabel(ylabel)
def plot_tf1_data_into(ax,
                       tf1,
                       label=None):
    """Plots the data of the tf1 into a matplotlib axes

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    tf1 : ROOT.TF1
        Function to be plotted.
    label : str, optional
        Label for the legend entry.

    Returns
    -------
    list or None
        The handles returned by ``ax.plot``, or None if the function
        evaluates to zero at every sampled point and nothing was drawn.
    """
    lower_bound = tf1.GetXmin()
    # Previously read GetXmin() a second time, which made both bounds always
    # equal and the range of the function itself unused.
    upper_bound = tf1.GetXmax()
    n_plot_points = max(tf1.GetNpx(), 100)

    # Without a usable range on the function, sample over the visible x range.
    if lower_bound == upper_bound:
        lower_bound, upper_bound = ax.get_xlim()

    xs = np.linspace(lower_bound, upper_bound, n_plot_points)
    ys = [tf1.Eval(x) for x in xs]

    root_color_index = tf1.GetLineColor()
    linecolor = root_color_to_matplot_color(root_color_index)

    # A function that is zero everywhere is not drawn.
    if any(y != 0 for y in ys):
        line_handles = ax.plot(xs, ys, color=linecolor, label=label)
        return line_handles
    return None
def root_color_to_matplot_color(root_color_index):
    """Look up a ROOT color and return its red, green and blue components.

    Parameters
    ----------
    root_color_index : int
        Index of a color as defined in ROOT

    Returns
    -------
    (float, float, float)
        The values of GetRed(), GetGreen() and GetBlue() of the ROOT color.
    """
    root_color = ROOT.gROOT.GetColor(root_color_index)
    red = root_color.GetRed()
    green = root_color.GetGreen()
    blue = root_color.GetBlue()
    return (red, green, blue)
def reformat_root_latex_to_matplotlib_latex(text):
    """Translate text with ROOT pseudo latex directives into latex for matplotlib.

    The text is split at single spaces and every part is translated on its
    own. A part recognised by ``is_root_latex_directive`` is wrapped into
    math mode as a whole with its hash signs turned into backslashes. In any
    other part each hash-introduced directive is wrapped into its own math
    mode. The parts are joined with single spaces again.
    """
    def translate(word):
        if is_root_latex_directive(word):
            # The whole word becomes one math mode expression.
            return r'$' + word.replace('#', '\\') + r'$'
        # Only the embedded directives are moved into math mode.
        return re.sub("#([a-zA-Z_{}]*)", r"$\\\1$", word)

    return " ".join(translate(word) for word in text.split(" "))
def is_root_latex_directive(text_part):
    """Check whether a piece of text has the look of a ROOT latex directive.

    That is the case when it begins with a hash sign or contains one of the
    character sequences '_{' or '{}'.
    """
    if text_part.startswith('#'):
        return True
    return any(marker in text_part for marker in ('_{', '{}'))
def common_bounds(matplot_bounds, root_bounds):
    """Merge the axis bounds of matplotlib with the bounds reported by ROOT.

    ROOT bounds that are both zero are ignored and the matplotlib bounds are
    returned as they are. Otherwise the smaller lower and the larger upper
    bound of the two pairs are returned, with NaN entries disregarded.
    """
    matplot_lower, matplot_upper = matplot_bounds
    root_lower, root_upper = root_bounds

    # Both ROOT bounds being zero is taken as "no bounds given".
    if root_lower == 0 and root_upper == 0:
        return matplot_lower, matplot_upper

    merged_lower = np.nanmin((matplot_lower, root_lower))
    merged_upper = np.nanmax((matplot_upper, root_upper))
    return merged_lower, merged_upper