Belle II Software development
matplotting.py
1
8
"""Tools to plot ROOT objects into matplotlib figures"""
10
11import ROOT
12
13import re
14import functools
15import numpy as np
16import collections
17from tracking.validation.plot import ValidationPlot
18import ctypes
19
20import logging
21
22import sys
23
# Largest finite float; passed to TH1.GetMaximum as the upper cut-off.
flt_max = sys.float_info.max

# NOTE(review): sys.float_info.min is the smallest *positive* normalized float
# (about 2.2e-308), not the most negative float. It is passed to TH1.GetMinimum
# as the lower cut-off, so only bins with positive content count there --
# confirm this is intended.
flt_min = sys.float_info.min

# Do not warn about invalid floating point operations (e.g. 0/0 -> nan),
# presumably those arising in the summary statistics of empty histograms.
np.seterr(invalid='ignore')
29
30
def get_logger():
    """Return the logger named after this module."""
    module_name = __name__
    return logging.getLogger(module_name)
34
35
try:
    import matplotlib
    # Switch to the non-interactive Agg backend, which renders without a display.
    matplotlib.use('Agg')

    import matplotlib.pyplot as plt
except ImportError as err:
    # Chain explicitly so that the underlying import failure stays visible.
    raise ImportError("matplotlib is not installed in your basf2 environment. "
                      "You may install it with 'pip install matplotlib'") from err
45
46
class defaults:
    """Default values of the plotting options"""

    # Matplotlib style used when a plot function receives no (or a None) ``style``.
    style = "bmh"

    # Whether a by-side legend is drawn when ``legend`` is None.
    legend = True

    # Whether statistics labels are created when ``label`` is None.
    label = True
56
57
# ROOT classes accepted by `is_plotable`. Some entries derive from others
# (e.g. TGraphErrors from TGraph); they are listed explicitly for readability.
plotable_classes = (
    ROOT.TH1,
    ROOT.TH2,
    ROOT.THStack,
    ROOT.TGraph,
    ROOT.TGraphErrors,
    ROOT.TMultiGraph,
)
66
67
def is_plotable(tobject):
    """Tell whether an object can be plotted with matplotlib by this module.

    Parameters
    ----------
    tobject : object
        Object to test, typically a ROOT.TObject.

    Returns
    -------
    bool
        True if ``tobject`` is an instance of one of ``plotable_classes``.
    """
    supported_types = plotable_classes
    return isinstance(tobject, supported_types)
71
72
def plot(tobject, **kwd):
    """Plot the given plotable TObject.

    Dispatches to the plotting function matching the type of ``tobject``.

    Parameters
    ----------
    tobject : ROOT.TObject
        Plotable TObject: TH1, TProfile, TH2, THStack, TGraph (incl. TGraphErrors)
        or TMultiGraph.
    label : bool, optional
        Create labels containing statistical information.
    legend : bool, optional
        Create a by-side legend containing statistical information.
    style : list(str), optional
        List of matplotlib styles to be used for the plotting.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot

    Raises
    ------
    ValueError
        If the type of ``tobject`` is not supported.
    """
    # TH2 must be tested before the generic TH1 fallback below, since TH2 derives from TH1.
    # NOTE(review): TH3 also derives from TH1 and therefore reaches plot_th1 -- confirm intended.
    if isinstance(tobject, ROOT.TH2):
        return plot_th2(tobject, **kwd)

    elif isinstance(tobject, ROOT.THStack):
        return plot_thstack(tobject, **kwd)

    elif isinstance(tobject, ROOT.TMultiGraph):
        return plot_tmultigraph(tobject, **kwd)

    elif isinstance(tobject, ROOT.TGraph):
        return plot_tgraph(tobject, **kwd)

    elif isinstance(tobject, (ROOT.TProfile, ROOT.TH1)):
        return plot_th1(tobject, **kwd)

    else:
        raise ValueError("Plotting to matplotlib only supported for TH1, TProfile, TH2, THStack, "
                         "TGraph and TMultiGraph.")
107
108
def use_style(plot_function):
    """Decorator that configures matplotlib's style and fonts before plotting.

    The style sheet is taken from the ``style`` keyword argument of the decorated
    call, falling back to ``defaults.style`` when it is absent or None. All
    arguments are forwarded unchanged to the decorated function.
    """
    @functools.wraps(plot_function)
    def plot_function_with_style(*args, **kwds):
        requested_style = kwds.get("style")
        matplotlib.style.use(defaults.style if requested_style is None else requested_style)

        rc_overrides = {
            "patch.linewidth": 2,
            # Typeset math text with DejaVu Sans so it matches the regular text.
            'mathtext.fontset': 'custom',
            'mathtext.rm': 'DejaVu Sans',
            'mathtext.it': 'DejaVu Sans:italic',
            'mathtext.cal': 'DejaVu Sans:italic',
            'mathtext.bf': 'DejaVu Sans:bold',
            'font.sans-serif': ['DejaVu Sans'],
            'font.monospace': ['cmtt10'],
        }
        for rc_key, rc_value in rc_overrides.items():
            matplotlib.rcParams[rc_key] = rc_value

        return plot_function(*args, **kwds)

    return plot_function_with_style
132
133
@use_style
def plot_th2(th2,
             label=None,
             legend=None,
             style=None,
             **kwd):
    """Plot a two dimensional histogram into a new figure.

    Parameters
    ----------
    th2 : ROOT.TH2
        Histogram to plot.
    label : bool, optional
        Whether to attach a statistics label; None falls back to ``defaults.label``.
    legend : bool, optional
        Whether to draw a by-side legend; None falls back to ``defaults.legend``.
    style : str or list(str), optional
        Matplotlib style, applied by the ``use_style`` decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.
    """
    fig, ax = create_figure(title=create_title(th2))
    plot_th2_data_into(ax, th2, label=create_label(th2, label))
    put_legend(fig, ax, legend=legend)
    return fig
150
151
@use_style
def plot_thstack(thstack,
                 label=None,
                 legend=None,
                 style=None,
                 **kwd):
    """Plot the histograms of a stack, overlaid (not stacked), into a new figure.

    Parameters
    ----------
    thstack : ROOT.THStack
        Stack of one dimensional histograms.
    label : bool, optional
        Whether to attach statistics labels; None falls back to ``defaults.label``.
    legend : bool, optional
        Whether to draw a by-side legend; None falls back to ``defaults.legend``.
    style : str or list(str), optional
        Matplotlib style, applied by the ``use_style`` decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot. For a stack without histograms only the
        titled, empty figure is returned (this is logged at info level).

    Raises
    ------
    ValueError
        If the stack holds three dimensional histograms or histograms of
        mismatching dimensions.
    NotImplementedError
        If the stack holds two dimensional histograms.
    """
    title = create_title(thstack)
    fig, ax = create_figure(title=title)

    # GetHists() may return a null list for a stack to which nothing was added.
    hists = thstack.GetHists()
    ths = list(hists) if hists else []

    if not ths:
        # all() over an empty list is True, which would misreport the stack as 3D below.
        get_logger().info("Skipping empty stack %s", thstack.GetName())
        return fig

    if all(isinstance(th, ROOT.TH3) for th in ths):
        raise ValueError("Cannot plot a stack of three dimensional histograms")

    elif all(isinstance(th, ROOT.TH2) and not isinstance(th, ROOT.TH3) for th in ths):
        raise NotImplementedError("Cannot plot a stack of two dimensional histograms")

    elif all(isinstance(th, ROOT.TH1) and not isinstance(th, (ROOT.TH3, ROOT.TH2)) for th in ths):
        # Currently the histograms are only overlaid, not stacked.
        for th1 in ths:
            th1_label = create_label(th1, label=label)
            plot_th1_data_into(ax, th1, label=th1_label)

        # Fixing that the limits sometimes clip in the y direction
        max_bin_content = thstack.GetMaximum("nostack")
        ax.set_ylim(0, 1.02 * max_bin_content)

    else:
        raise ValueError("Stack of histograms with mismatching dimensions")

    put_legend(fig, ax, legend=legend)
    return fig
188
189
@use_style
def plot_tmultigraph(tmultigraph,
                     label=None,
                     legend=None,
                     style=None,
                     **kwd):
    """Plot all graphs of a multigraph overlaid into a new figure.

    Parameters
    ----------
    tmultigraph : ROOT.TMultiGraph
        Collection of graphs to plot.
    label : bool, optional
        Whether to attach statistics labels; None falls back to ``defaults.label``.
    legend : bool, optional
        Whether to draw a by-side legend; None falls back to ``defaults.legend``.
    style : str or list(str), optional
        Matplotlib style, applied by the ``use_style`` decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.
    """
    fig, ax = create_figure(title=create_title(tmultigraph))

    for graph in tmultigraph.GetListOfGraphs():
        plot_tgraph_data_into(ax, graph, label=create_label(graph, label=label))

    # Widen the autoscaled y range to include the limits reported by the multigraph.
    # NOTE(review): ROOT commonly uses -1111 as the "unset" value for stored
    # minima/maxima; if GetMinimum/GetMaximum return that here, the range would be
    # stretched to -1111 -- confirm.
    autoscaled_limits = ax.get_ylim()
    root_limits = (tmultigraph.GetMinimum(), tmultigraph.GetMaximum())
    ax.set_ylim(*common_bounds(autoscaled_limits, root_limits))

    put_legend(fig, ax, legend=legend)
    return fig
214
215
@use_style
def plot_th1(th1,
             label=None,
             legend=None,
             style=None,
             **kwd):
    """Plot a one dimensional histogram, plus its attached fit function if any.

    Parameters
    ----------
    th1 : ROOT.TH1
        Histogram to plot.
    label : bool, optional
        Whether to attach statistics labels; None falls back to ``defaults.label``.
    legend : bool, optional
        Whether to draw a by-side legend; None falls back to ``defaults.legend``.
    style : str or list(str), optional
        Matplotlib style, applied by the ``use_style`` decorator.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot. For a histogram whose sum of weights is
        zero only the titled, empty figure is returned (logged at info level).
    """
    fig, ax = create_figure(title=create_title(th1))
    histogram_label = create_label(th1, label)

    if th1.GetSumOfWeights() == 0:
        get_logger().info("Skipping empty histogram %s", th1.GetName())
        return fig

    plot_th1_data_into(ax, th1, label=histogram_label)

    # The fit line (if drawn) gets its own legend section titled "Fit".
    fit_handles = None
    fit_function = get_fit(th1)
    if fit_function:
        fit_handles = plot_tf1_data_into(ax, fit_function,
                                         label=create_label(fit_function, label=label))

    put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title="Fit")
    return fig
244
245
@use_style
def plot_tgraph(tgraph,
                label=None,
                legend=None,
                style=None,
                **kwd):
    """Plot a graph into a new figure, plus its attached fit function if any.

    Parameters
    ----------
    tgraph : ROOT.TGraph
        Graph to plot.
    label : bool, optional
        Whether to attach statistics labels; None falls back to ``defaults.label``.
    legend : bool, optional
        Whether to draw a by-side legend; None falls back to ``defaults.legend``.
    style : str or list(str), optional
        Matplotlib style, applied by the ``use_style`` decorator.
    **kwd
        Further options are accepted and ignored, like in the other plot_*
        functions, so that ``plot`` can forward its keyword arguments unchanged.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.
    """
    title = create_title(tgraph)
    fig, ax = create_figure(title=title)

    tgraph_label = create_label(tgraph, label=label)
    plot_tgraph_data_into(ax, tgraph, label=tgraph_label)

    tf1 = get_fit(tgraph)
    if tf1:
        tf1_label = create_label(tf1, label=label)
        fit_handles = plot_tf1_data_into(ax, tf1, label=tf1_label)
    else:
        fit_handles = None

    put_legend(fig, ax, legend=legend, top_handles=fit_handles, top_title="Fit")
    return fig
267
268
def create_figure(title=None):
    """Create a new figure holding a single axes.

    Parameters
    ----------
    title : str, optional
        Axes title, placed slightly above its default position. Skipped when
        empty or None.

    Returns
    -------
    (matplotlib.figure.Figure, matplotlib.axes.Axes)
        The new figure and the axes into which can be plotted.
    """
    figure = plt.figure()
    axes = figure.add_subplot(1, 1, 1)

    if title:
        axes.set_title(title, y=1.04)

    return figure, axes
284
285
def create_title(tplotable):
    """Return the title of a ROOT object with ROOT latex translated for matplotlib."""
    root_title = tplotable.GetTitle()
    return reformat_root_latex_to_matplotlib_latex(root_title)
289
290
def create_label(th_or_tgraph, label=None):
    """Create a legend label from the summary statistics of a ROOT object.

    Parameters
    ----------
    th_or_tgraph : ROOT.TH1, ROOT.TGraph or ROOT.TF1
        Object to summarize. Histograms and graphs contribute their statistics,
        fit functions their formula and free parameters.
    label : bool, optional
        Whether a label should be created; None falls back to ``defaults.label``.
        A truthy value is replaced by the generated statistics text.

    Returns
    -------
    str or None
        The label text, or None when labels are disabled.

    Raises
    ------
    ValueError
        If labels are enabled and the object has an unsupported type.
    """
    if label is None:
        label = defaults.label

    if not label:
        return None

    if isinstance(th_or_tgraph, ROOT.TH1):
        summary = get_stats_from_th(th_or_tgraph)
    elif isinstance(th_or_tgraph, ROOT.TGraph):
        summary = get_stats_from_tgraph(th_or_tgraph)
    elif isinstance(th_or_tgraph, ROOT.TF1):
        summary = get_fit_parameters(th_or_tgraph)
    else:
        raise ValueError("Can only create a label from a ROOT.TH or ROOT.TGraph")

    return compose_stats_label("", summary)
317
318
def put_legend(fig,
               ax,
               legend=None,
               top_handles=None,
               top_title=None,
               bottom_title=None):
    """Place the legend(s) of a plot to the right of the axes.

    If a legend is requested, the figure is widened by a third and the axes is
    shrunk to 75 % of its width to free space on the right. With ``top_handles``,
    those handles get their own legend at the top right (titled ``top_title``) and
    all remaining handles go into a legend at the bottom right (titled
    ``bottom_title``). Without ``top_handles`` all handles go into the bottom legend.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to widen.
    ax : matplotlib.axes.Axes
        Axes whose handles are shown.
    legend : bool, optional
        Whether to draw any legend; None falls back to ``defaults.legend``.
    top_handles : list, optional
        Handles for the separate top legend.
    top_title, bottom_title : str, optional
        Titles of the top and bottom legends.
    """
    show_legend = defaults.legend if legend is None else legend
    if not show_legend:
        return

    # Widen the figure by a third ...
    fig.set_figwidth(4.0 / 3.0 * fig.get_figwidth())

    # ... and shrink the axes to 75 % of its width, leaving room on the right.
    position = ax.get_position()
    ax.set_position([position.x0, position.y0, position.width * 0.75, position.height])

    if not top_handles:
        put_legend_outside(ax, title=bottom_title)
        return

    # The bottom legend is re-added as an artist so the following top legend does not replace it.
    put_legend_outside(ax,
                       exclude_handles=top_handles,
                       force_add_legend=True,
                       title=bottom_title)
    put_legend_outside(ax,
                       select_handles=top_handles,
                       bottom=False,
                       title=top_title)
356
357
def put_legend_outside(ax,
                       title=None,
                       bottom=True,
                       select_handles=None,
                       exclude_handles=None,
                       force_add_legend=False,
                       ):
    """Put a legend right beside the axes space.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes the legend belongs to.
    title : str, optional
        Legend title.
    bottom : bool, optional
        If True, anchor the lower left corner of the legend next to the bottom right
        of the axes; otherwise anchor its upper left corner next to the top right.
    select_handles : list, optional
        Handles to show. When empty or None, all handles of ``ax`` are used.
    exclude_handles : list, optional
        Handles to drop from the selection.
    force_add_legend : bool, optional
        Also add the legend as a plain artist, so it stays visible when a later
        ``ax.legend`` call replaces the axes' legend.
    """
    if not select_handles:
        select_handles, _ = ax.get_legend_handles_labels()

    if exclude_handles:
        select_handles = [handle for handle in select_handles if handle not in exclude_handles]

    # x = 1.02 in axes coordinates is just right of the axes' right edge.
    if bottom:
        anchor, loc = (1.02, 0), 3  # loc 3 == "lower left"
    else:
        anchor, loc = (1.02, 1), 2  # loc 2 == "upper left"

    legend_artist = ax.legend(handles=select_handles,
                              bbox_to_anchor=anchor,
                              borderaxespad=0.,
                              loc=loc,
                              prop={"family": "monospace"},
                              title=title)

    if force_add_legend:
        ax.add_artist(legend_artist)
409
410
def get_fit(th1_or_tgraph):
    """Return the first TF1 in the object's list of functions, or None if there is none."""
    attached = th1_or_tgraph.GetListOfFunctions()
    return next((function for function in attached if isinstance(function, ROOT.TF1)), None)
416
417
def get_fit_parameters(tf1):
    """Collect the formula and the free parameters of a fit function.

    Fixed parameters are skipped: they usually hold additional stats entries that
    are already shown in the main legend of the plot. A parameter counts as fixed
    when its lower and upper limits are equal and non-zero.

    Parameters
    ----------
    tf1 : ROOT.TF1
        Fit function.

    Returns
    -------
    collections.OrderedDict
        ``"formula"`` mapped to the title of ``tf1``, followed by name -> value
        of each free parameter, in parameter order.
    """
    parameters = collections.OrderedDict([("formula", tf1.GetTitle())])

    for index in range(tf1.GetNpar()):
        low, high = ctypes.c_double(), ctypes.c_double()
        tf1.GetParLimits(index, low, high)

        name = tf1.GetParName(index)
        value = tf1.GetParameter(index)

        is_fixed = low.value == high.value and low.value != 0
        if not is_fixed:
            parameters[name] = value

    return parameters
446
447
def get_stats_from_tgraph(tgraph):
    """Return the summary statistics of a graph.

    Only the additional statistics provided by
    ``ValidationPlot.get_additional_stats`` are reported, in their original order.
    The result is an empty OrderedDict when there are none.
    """
    extra_stats = ValidationPlot.get_additional_stats(tgraph)
    pairs = extra_stats.items() if extra_stats else ()
    return collections.OrderedDict(pairs)
458
459
def get_stats_from_th(th):
    """Get the summary statistics from the histogram

    Parameters
    ----------
    th : ROOT.TH1
        Histogram (TH1, TProfile, TH2, ...) to summarize.

    Returns
    -------
    collections.OrderedDict
        In this order:
        - ``"count"``, the number of entries;
        - the additional stats from ``ValidationPlot.get_additional_stats``;
        - for histograms that are neither TH2 nor TH3, the non-zero ``"x underf."``
          and ``"x overf."`` contents (bin entries for TProfile);
        - ``"x avg"`` and ``"x std"``;
        - if ``GetStats`` provides y sums, ``"y avg"`` and ``"y std"``;
        - if it also provides the xy sum, ``"cov"`` and ``"corr"``.
    """
    stats = collections.OrderedDict()

    stats["count"] = th.GetEntries()

    additional_stats = ValidationPlot.get_additional_stats(th)
    if additional_stats:
        for key, value in list(additional_stats.items()):
            stats[key] = value

    if not isinstance(th, (ROOT.TH2, ROOT.TH3)):
        x_taxis = th.GetXaxis()
        n_bins = x_taxis.GetNbins()

        if isinstance(th, ROOT.TProfile):
            underflow_content = th.GetBinEntries(0)
            overflow_content = th.GetBinEntries(n_bins + 1)
        else:
            underflow_content = th.GetBinContent(0)
            overflow_content = th.GetBinContent(n_bins + 1)

        if underflow_content:
            stats["x underf."] = underflow_content

        if overflow_content:
            stats["x overf."] = overflow_content

    # GetStats writes as many sums as the concrete histogram class defines, up to
    # ROOT's TH1::kNstat = 13 values (e.g. 9 for TProfile2D, 11 for TH3). A smaller
    # buffer would be overrun. Entries that are not written stay NaN, which is used
    # below to detect the dimensionality.
    stats_values = np.full(13, np.nan)
    th.GetStats(stats_values)

    sum_w = stats_values[0]
    # stats_values[1] holds the sum of squared weights (unused)
    sum_wx = stats_values[2]
    sum_wx2 = stats_values[3]
    sum_wy = stats_values[4]  # Only for TProfile and two or more dimensions
    sum_wy2 = stats_values[5]  # Only for TProfile and two or more dimensions
    sum_wxy = stats_values[6]  # Only for two or more dimensions

    stats["x avg"] = np.divide(sum_wx, sum_w)
    stats["x std"] = np.divide(np.sqrt(sum_wx2 * sum_w - sum_wx * sum_wx), sum_w)

    if not np.isnan(sum_wy):
        stats["y avg"] = np.divide(sum_wy, sum_w)
        stats["y std"] = np.divide(np.sqrt(sum_wy2 * sum_w - sum_wy * sum_wy), sum_w)

        if not np.isnan(sum_wxy):
            stats["cov"] = np.divide((sum_wxy * sum_w - sum_wx * sum_wy), (sum_w * sum_w))
            stats["corr"] = np.divide(stats["cov"], (stats["x std"] * stats["y std"]))

    return stats
516
517
def compose_stats_label(title, additional_stats=None):
    """Render a title and summary statistics into a multi-line label.

    Each statistic becomes one line: the key left-aligned to 9 characters, then
    ``": "``, then the value. String values are right-aligned to 9 characters;
    other values are shown with 3 significant digits. An empty title is omitted.

    Returns
    -------
    str
        The lines joined with newlines.
    """
    stats = {} if additional_stats is None else additional_stats

    lines = [str(title)] if title else []
    for key, value in stats.items():
        template = "{0:<9}: {1:>9s}" if isinstance(value, str) else "{0:<9}: {1:.3g}"
        lines.append(template.format(key, value))

    return "\n".join(lines)
536
537
def plot_tgraph_data_into(ax,
                          tgraph,
                          plot_errors=None,
                          label=None,
                          clip_to_data=True):
    """Plot a ROOT TGraph into a matplotlib axes

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    tgraph : ROOT.TGraph
        A plotable ROOT graph
    plot_errors : bool, optional
        Plot graph as errorbar plot. Default None means True for TGraphErrors and False else.
    label : str
        label to be given to the plot
    clip_to_data : bool, optional
        Autoscale ``ax`` tightly around the data while plotting.
    """
    # Act on the given axes rather than on pyplot's notion of the current axes.
    ax.autoscale(tight=clip_to_data)

    if plot_errors is None:
        plot_errors = isinstance(tgraph, ROOT.TGraphErrors)

    xlabel = reformat_root_latex_to_matplotlib_latex(tgraph.GetXaxis().GetTitle())
    ylabel = reformat_root_latex_to_matplotlib_latex(tgraph.GetYaxis().GetTitle())

    n_points = tgraph.GetN()

    xs = np.empty(n_points, dtype=float)
    ys = np.empty(n_points, dtype=float)

    x_lower_errors = np.empty(n_points, dtype=float)
    x_upper_errors = np.empty(n_points, dtype=float)

    y_lower_errors = np.empty(n_points, dtype=float)
    y_upper_errors = np.empty(n_points, dtype=float)

    # GetPoint writes the coordinates into these output parameters.
    x = ctypes.c_double()
    y = ctypes.c_double()

    for i_point in range(n_points):
        tgraph.GetPoint(i_point, x, y)
        xs[i_point] = float(x.value)
        ys[i_point] = float(y.value)

        x_lower_errors[i_point] = tgraph.GetErrorXlow(i_point)
        x_upper_errors[i_point] = tgraph.GetErrorXhigh(i_point)

        y_lower_errors[i_point] = tgraph.GetErrorYlow(i_point)
        y_upper_errors[i_point] = tgraph.GetErrorYhigh(i_point)

    if plot_errors:
        linecolor = root_color_to_matplot_color(tgraph.GetLineColor())

        ax.errorbar(xs,
                    ys,
                    xerr=[x_lower_errors, x_upper_errors],
                    yerr=[y_lower_errors, y_upper_errors],
                    fmt="none",
                    ecolor=linecolor,
                    label=label)

    else:
        markercolor = root_color_to_matplot_color(tgraph.GetMarkerColor())

        ax.scatter(xs,
                   ys,
                   c=markercolor,
                   s=2,
                   marker="+",
                   label=label)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.autoscale(tight=None)
625
626
def plot_th1_data_into(ax,
                       th1,
                       plot_errors=None,
                       label=None):
    """Plot a ROOT histogram into a matplotlib axes

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    th1 : ROOT.TH1
        A plotable one dimensional ROOT histogram
    plot_errors : bool, optional
        Plot histogram as errorbar plot. Default None means True for TProfile and,
        for other histograms, True only if they store the sum of squared weights (Sumw2).
    label : str, optional
        label to be given to the histogram
    """

    if plot_errors is None:
        if isinstance(th1, ROOT.TProfile):
            plot_errors = True
        else:
            plot_errors = th1.GetSumw2N() != 0

    # Bin content
    x_taxis = th1.GetXaxis()
    n_bins = x_taxis.GetNbins()

    bin_ids_with_underflow = list(range(n_bins + 1))
    bin_ids_without_underflow = list(range(1, n_bins + 1))

    # Get the n_bins + 1 bin edges: the upper edges of bins 0 (underflow) to n_bins
    bin_edges = np.array([x_taxis.GetBinUpEdge(i_bin) for i_bin in bin_ids_with_underflow])

    # Bin center and width not including the underflow
    bin_centers = np.array([x_taxis.GetBinCenter(i_bin) for i_bin in bin_ids_without_underflow])
    bin_widths = np.array([x_taxis.GetBinWidth(i_bin) for i_bin in bin_ids_without_underflow])
    bin_x_errors = bin_widths / 2.0

    # Now for the histogram content not including the underflow
    bin_contents = np.array([th1.GetBinContent(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_errors = np.array([th1.GetBinError(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_upper_errors = np.array([th1.GetBinErrorUp(i_bin) for i_bin in bin_ids_without_underflow])
    bin_y_lower_errors = np.array([th1.GetBinErrorLow(i_bin) for i_bin in bin_ids_without_underflow])

    # Bins without content and without error are left out of the error bar plot
    empty_bins = (bin_contents == 0) & (bin_y_errors == 0)

    # Alphanumeric bin labels: plot as bars and use the labels as x ticks
    is_discrete_binning = bool(x_taxis.GetLabels())
    bin_labels = [x_taxis.GetBinLabel(i_bin) for i_bin in bin_ids_without_underflow]

    xlabel = reformat_root_latex_to_matplotlib_latex(x_taxis.GetTitle())

    y_taxis = th1.GetYaxis()
    ylabel = reformat_root_latex_to_matplotlib_latex(y_taxis.GetTitle())

    # May set these from th1 properties
    y_log_scale = False
    histtype = "step"

    root_color_index = th1.GetLineColor()
    linecolor = root_color_to_matplot_color(root_color_index)

    if plot_errors:
        ax.errorbar(bin_centers[~empty_bins],
                    bin_contents[~empty_bins],
                    yerr=[bin_y_lower_errors[~empty_bins],
                          bin_y_upper_errors[~empty_bins]],
                    xerr=bin_x_errors[~empty_bins],
                    fmt="none",
                    ecolor=linecolor,
                    label=label)

        y_lower_bound, y_upper_bound = common_bounds(
            ax.get_ylim(),
            (th1.GetMinimum(flt_min), th1.GetMaximum(flt_max))
        )

        y_total_width = y_upper_bound - y_lower_bound
        ax.set_ylim(y_lower_bound - 0.02 * y_total_width, y_upper_bound + 0.02 * y_total_width)

    else:
        if is_discrete_binning:
            # Bars centered on the bins, i.e. on the tick positions set below,
            # each covering 80 % of its bin.
            ax.bar(bin_centers,
                   bin_contents,
                   width=0.8 * bin_widths,
                   align="center",
                   label=label,
                   edgecolor=linecolor,
                   color=(1, 1, 1, 0),  # fill transparent white
                   log=y_log_scale)
        else:
            ax.hist(bin_centers,
                    bins=bin_edges,
                    weights=bin_contents,
                    edgecolor=linecolor,
                    histtype=histtype,
                    label=label,
                    log=y_log_scale)

        # Fixing that the limits sometimes clip in the y direction:
        # include zero, add a 2 % margin, and skip degenerate or non-finite ranges
        # (e.g. all-zero or NaN contents) that matplotlib would reject or invert.
        y_lower_bound = 1.02 * min(np.nanmin(bin_contents), 0.0)
        y_upper_bound = 1.02 * max(np.nanmax(bin_contents), 0.0)
        if np.isfinite(y_lower_bound) and np.isfinite(y_upper_bound) and y_lower_bound < y_upper_bound:
            ax.set_ylim(y_lower_bound, y_upper_bound)

    if is_discrete_binning:
        ax.set_xticks(bin_centers)
        ax.set_xticklabels(bin_labels, rotation=0)

    total_width = bin_edges[-1] - bin_edges[0]
    if total_width != 0:
        ax.set_xlim(bin_edges[0] - 0.01 * total_width, bin_edges[-1] + 0.01 * total_width)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
738
739
def plot_th2_data_into(ax,
                       th2,
                       plot_3d=False,
                       label=None):
    """Plot a ROOT two dimensional histogram into a matplotlib axes as a color map

    Every regular bin becomes one sample at its bin center, weighted with the bin
    content and histogrammed again with ``ax.hist2d`` on the ROOT bin edges.
    Under- and overflow are not drawn. A colorbar is placed in an extra axes
    created next to ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    th2 : ROOT.TH2
        A plotable two dimensional ROOT histogram
    plot_3d : bool, optional
        Plot as a three dimensional plot (not implemented)
    label : str, optional
        label to be given to the histogram

    Raises
    ------
    NotImplementedError
        If ``plot_3d`` is True.
    """
    x_axis, y_axis = th2.GetXaxis(), th2.GetYaxis()

    # Regular bins 1..n of each axis
    x_bins = list(range(1, x_axis.GetNbins() + 1))
    y_bins = list(range(1, y_axis.GetNbins() + 1))

    # n + 1 edges per axis: the upper edges of bins 0 (underflow) .. n
    x_edges = np.array([x_axis.GetBinUpEdge(i_bin) for i_bin in [0] + x_bins])
    y_edges = np.array([y_axis.GetBinUpEdge(i_bin) for i_bin in [0] + y_bins])

    x_bin_centers = np.array([x_axis.GetBinCenter(i_bin) for i_bin in x_bins])
    y_bin_centers = np.array([y_axis.GetBinCenter(i_bin) for i_bin in y_bins])

    # One sample per (x, y) bin, ordered consistently for centers and contents
    x_centers, y_centers = np.meshgrid(x_bin_centers, y_bin_centers)
    x_ids, y_ids = np.meshgrid(x_bins, y_bins)
    contents = np.array([th2.GetBinContent(int(i_x), int(i_y))
                         for i_x, i_y in zip(x_ids.flat, y_ids.flat)])

    x_has_labels = bool(x_axis.GetLabels())
    x_tick_labels = [x_axis.GetBinLabel(i_bin) for i_bin in x_bins]

    y_has_labels = bool(y_axis.GetLabels())
    y_tick_labels = [y_axis.GetBinLabel(i_bin) for i_bin in y_bins]

    x_title = reformat_root_latex_to_matplotlib_latex(x_axis.GetTitle())
    y_title = reformat_root_latex_to_matplotlib_latex(y_axis.GetTitle())

    # May be derived from th2 properties in the future
    log_scale = False

    if plot_3d:
        raise NotImplementedError("3D plotting of two dimensional histograms not implemented yet")

    hist2d_options = dict(weights=contents,
                          bins=[x_edges, y_edges],
                          label=label)
    if log_scale:
        hist2d_options["norm"] = matplotlib.colors.LogNorm()

    _, _, _, color_mappable = ax.hist2d(x_centers.flatten(),
                                        y_centers.flatten(),
                                        **hist2d_options)

    # Zero-area dummy artist in the lowest color of the map, representing the histogram in the legend
    lowest_color = color_mappable.get_cmap()(0)
    ax.fill(0, 0, "-", color=lowest_color, label=label)

    # make_axes carves the colorbar axes out of the space of ax
    colorbar_axes, _ = matplotlib.colorbar.make_axes(ax,
                                                     pad=0.02,
                                                     shrink=0.5,
                                                     anchor=(0.0, 1.0),
                                                     panchor=(1.0, 1.0),
                                                     aspect=10,
                                                     )
    matplotlib.colorbar.Colorbar(colorbar_axes, color_mappable)

    if x_has_labels:
        ax.set_xticks(x_bin_centers)
        ax.set_xticklabels(x_tick_labels, rotation=0)

    if y_has_labels:
        ax.set_yticks(y_bin_centers)
        ax.set_yticklabels(y_tick_labels, rotation=0)

    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)
849
850
def plot_tf1_data_into(ax,
                       tf1,
                       label=None):
    """Plots the data of the tf1 into a matplotlib axes

    The function is sampled at ``max(tf1.GetNpx(), 100)`` equidistant points over
    its range ``[GetXmin(), GetXmax()]``. If that range is degenerate, the current
    x limits of ``ax`` are used instead.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An axes space in which to plot
    tf1 : ROOT.TF1
        Function to be plotted.
    label : str, optional
        Label for the legend entry.

    Returns
    -------
    list of matplotlib.lines.Line2D or None
        The handles of the drawn line, or None if the function is zero at every
        sample point (nothing is drawn in that case).
    """

    lower_bound = tf1.GetXmin()
    upper_bound = tf1.GetXmax()
    n_plot_points = max(tf1.GetNpx(), 100)

    if lower_bound == upper_bound:
        lower_bound, upper_bound = ax.get_xlim()

    xs = np.linspace(lower_bound, upper_bound, n_plot_points)
    ys = [tf1.Eval(x) for x in xs]

    root_color_index = tf1.GetLineColor()
    linecolor = root_color_to_matplot_color(root_color_index)

    if any(y != 0 for y in ys):
        line_handles = ax.plot(xs, ys, color=linecolor, label=label)
        return line_handles
    return None
882
883
def root_color_to_matplot_color(root_color_index):
    """Look up a ROOT color index and return it as an RGB tuple.

    Parameters
    ----------
    root_color_index : int
        Index of a color as defined in ROOT

    Returns
    -------
    (float, float, float)
        Red, green and blue fractions as reported by the ROOT TColor.
    """
    color = ROOT.gROOT.GetColor(root_color_index)
    red, green, blue = color.GetRed(), color.GetGreen(), color.GetBlue()
    return red, green, blue
899
900
def reformat_root_latex_to_matplotlib_latex(text):
    """Translate ROOT pseudo latex in ``text`` into latex understood by matplotlib.

    This is a simple heuristic, not a parser for ROOT's latex dialect. The text is
    split at single spaces and each word is treated separately, so some context
    may be lost:

    * a word that looks like a ROOT latex directive (see
      ``is_root_latex_directive``) is wrapped in ``$...$`` with every ``#``
      replaced by a backslash;
    * in the remaining words, each ``#name`` is replaced by ``$\\name$``.
    """
    def translate_word(word):
        if is_root_latex_directive(word):
            word = '$' + word.replace('#', '\\') + '$'
        return re.sub("#([a-zA-Z_{}]*)", r"$\\\1$", word)

    return " ".join(translate_word(word) for word in text.split(" "))
929
930
def is_root_latex_directive(text_part):
    """Heuristically decide whether a text fragment is ROOT latex.

    A fragment qualifies if it starts with ``#`` or contains ``_{`` or ``{}``.
    """
    if text_part.startswith('#'):
        return True
    return any(marker in text_part for marker in ('_{', '{}'))
934
935
def common_bounds(matplot_bounds, root_bounds):
    """Merge matplotlib axis limits with limits reported by ROOT.

    ROOT limits of exactly (0, 0) are treated as "not set" and the matplotlib
    limits are returned unchanged. Otherwise the result spans both ranges,
    ignoring NaN values.

    Returns
    -------
    (lower, upper)
        The combined bounds.
    """
    mpl_lower, mpl_upper = matplot_bounds
    root_lower, root_upper = root_bounds

    if root_lower == 0 and root_upper == 0:
        return mpl_lower, mpl_upper

    return np.nanmin((mpl_lower, root_lower)), np.nanmax((mpl_upper, root_upper))
Definition: plot.py:1