Belle II Software release-09-00-14
plotting.py
1
8"""
9Plotting functions for alignment validation.
10
11Module-level config (set by caller before running):
12 output_dir — directory where plots are saved (default: "validation_plots")
13 file_format — image format (default: "pdf")
14"""
15
16import re
17
18import matplotlib.pyplot as plt
19import numpy as np
20from scipy.optimize import curve_fit
21
22from alignment_validation.utils import auto_range, normal_distribution, to_bins
23
## Output directory for saved plots. Every plotting function in this file
## writes to ``{output_dir}/<name>.{file_format}``; callers are expected to
## reassign it before plotting. None of the functions visible here creates
## the directory.
output_dir = "validation_plots"
## Image file format. Used as the file extension of every saved plot and, in
## most functions, also passed explicitly as the ``format`` argument of savefig.
file_format = "pdf"

# Global typography defaults. This statement runs at import time and updates
# matplotlib's process-wide rcParams, so it also applies to figures created
# elsewhere after this module has been imported.
plt.rcParams.update({
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 12,
    "figure.titlesize": 16,
})
38
39
def plot_histogram(
        data_list: list,
        data_labels: list,
        plot_filename: str,
        xlabel: str,
        range: int = 96,
        opacity: float = 0.5,
        ):
    """Draw every dataset as a semi-transparent histogram on one set of axes and save it.

    All datasets share the same 30 bins and the same histogram limits, so the
    overlaid distributions are directly comparable.

    Parameters
    ----------
    data_list : list of array-like
        The datasets to histogram, one array each.
    data_labels : list of str
        Legend entry for each dataset, in the same order as ``data_list``.
    plot_filename : str
        Base name of the output file, without extension. The figure is
        written to ``{output_dir}/{plot_filename}.{file_format}``.
    xlabel : str
        Label of the x axis (LaTeX is supported by matplotlib).
    range : int or float or tuple, optional
        A number requests automatic limits: it is handed to
        :func:`auto_range` together with all datasets as the percentage of
        the data to cover. Anything else (typically a ``(min, max)`` tuple)
        is forwarded unchanged as the histogram range. Default is 96.
    opacity : float, optional
        Alpha value of the histogram bars. Default is 0.5.
    """
    # A plain number is a percentile request; everything else is explicit limits.
    wants_auto_limits = isinstance(range, (int, float))
    hist_ranges = auto_range(data_list, range, 0.1) if wants_auto_limits else range

    # Options shared by every overlaid histogram.
    shared_options = dict(bins=30, range=hist_ranges, alpha=opacity, density=False, rasterized=True)

    plt.figure(figsize=(5, 4))
    plt.ticklabel_format(scilimits=(-2, 3), useMathText=True)
    plt.xlabel(xlabel)
    plt.ylabel("Entries")
    for index, dataset in enumerate(data_list):
        plt.hist(dataset, label=data_labels[index], **shared_options)

    plt.legend()
    plt.savefig(f"{output_dir}/{plot_filename}.{file_format}", bbox_inches="tight")
    plt.close()
86
87
def plot_correlations(
        plot_type: str,
        x_data_list: list,
        y_data_list: list,
        x_data_labels: list,
        y_data_labels: list,
        data_labels: list,
        nbins: int,
        ranges: tuple = (96, 90),
        make_2D_hist: bool = True,
        figsize: tuple = (10.0, 7.5),
        ):
    """Plot a grid of per-bin median or resolution profiles for all x/y variable pairs.

    Produces one subplot per (x variable, y variable) combination, with one
    error-bar series per dataset in each subplot. Optionally also produces
    matching 2D histogram plots via :func:`plot_correlations_2D_histograms`.

    Saves to ``{output_dir}/{plot_type}_correlations.{file_format}``.

    Parameters
    ----------
    plot_type : str
        ``'median'`` to plot per-bin medians, or ``'resolution'`` to plot
        per-bin sigma68 values.
    x_data_list : list of dict
        One dict per dataset, mapping variable objects to x arrays. The keys
        of the first dict define the grid columns.
    y_data_list : list of dict
        One dict per dataset, mapping variable objects to y arrays. Must have
        the same keys across all datasets. The keys of the first dict define
        the grid rows.
    x_data_labels : list of str
        x-axis labels for each x variable, in the same order as the dict keys.
    y_data_labels : list of str
        y-axis labels for each y variable, in the same order as the dict keys.
    data_labels : list of str
        Legend labels, one per dataset.
    nbins : int
        Number of bins for the profile.
    ranges : tuple, optional
        ``(x_range, y_range)`` where each element is either a percentile
        integer/float or a pre-computed dict mapping variable objects to
        ``[min, max]`` limits. For a percentile, limits are computed per
        dataset and the envelope (smallest minimum, largest maximum) is used.
        Default is ``(96, 90)``.
    make_2D_hist : bool, optional
        If True, also call :func:`plot_correlations_2D_histograms` for each
        dataset. Default is True.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(10.0, 7.5)``.

    Raises
    ------
    ValueError
        If ``plot_type`` is neither ``'median'`` nor ``'resolution'``. The
        check happens before any figure is created, so no figure is left open.
    """
    titles = {"median": "Correlations", "resolution": "Resolution relations"}
    if plot_type not in titles:
        raise ValueError("plot_type must be 'median' or 'resolution'")

    x_vars = list(x_data_list[0].keys())
    y_vars = list(y_data_list[0].keys())

    # squeeze=False keeps axs two-dimensional even with a single x or y
    # variable, so the axs[row, col] indexing below is always valid.
    fig, axs = plt.subplots(len(y_vars), len(x_vars), sharex="col", sharey="row", squeeze=False)
    fig.set_size_inches(*figsize)
    fig.suptitle(titles[plot_type])
    fig.subplots_adjust(wspace=0.2, hspace=0.2)

    x_ranges, y_ranges = ranges[0], ranges[1]
    if isinstance(x_ranges, (int, float)):
        percentage = x_ranges
        x_ranges = {}
        for var in x_vars:
            limits = [auto_range([x_data[var]], percentage, modify=0.1) for x_data in x_data_list]
            x_ranges[var] = [min(lim[0] for lim in limits), max(lim[1] for lim in limits)]
    if isinstance(y_ranges, (int, float)):
        percentage = y_ranges
        y_ranges = {}
        for var in y_vars:
            limits = []
            for y_data in y_data_list:
                if plot_type == 'median':
                    limits.append(auto_range([y_data[var]], percentage, modify=0.1, symmetric=True))
                else:
                    # Resolutions are non-negative: the axis runs from 0 to the
                    # width of the central `percentage`% of the raw distribution.
                    limits.append(
                        (0,
                         np.percentile(y_data[var], 50 + percentage / 2) -
                         np.percentile(y_data[var], 50 - percentage / 2)))
            y_ranges[var] = [min(lim[0] for lim in limits), max(lim[1] for lim in limits)]
    # The resolved dicts are also what the optional 2D histograms receive.
    ranges = (x_ranges, y_ranges)

    last_row = len(y_vars) - 1
    for row, vary in enumerate(y_vars):
        for col, varx in enumerate(x_vars):
            ax = axs[row, col]
            # Axis labels only on the outer edge of the shared-axis grid.
            if row == last_row:
                ax.set_xlabel(x_data_labels[col], fontsize=10)
            if col == 0:
                ax.set_ylabel(y_data_labels[row], fontsize=10)
            ax.tick_params(labelsize=9)

            for i, x_data in enumerate(x_data_list):
                binned = to_bins(x_data[varx], y_data_list[i][vary], nbins, x_ranges[varx])
                # to_bins returns six values; elements 1 and 3 are drawn as the
                # median profile and its error, elements 4 and 5 as the
                # resolution profile and its error.
                if plot_type == 'median':
                    x_vals, y_vals, x_width, y_err, _, _ = binned
                else:
                    x_vals, _, x_width, _, y_vals, y_err = binned

                with np.errstate(invalid='ignore'):
                    ax.errorbar(x_vals, y_vals, y_err, x_width, fmt=".", label=data_labels[i])

            ax.set_xlim(x_ranges[varx][0], x_ranges[varx][1])
            ax.set_ylim(y_ranges[vary][0], y_ranges[vary][1])

    fig.legend(data_labels, loc="upper center", ncol=len(data_labels), bbox_to_anchor=(0.5, 0.95))
    fig.savefig(f"{output_dir}/{plot_type}_correlations.{file_format}", format=file_format)
    plt.close(fig)

    if make_2D_hist:
        for i, x_data in enumerate(x_data_list):
            plot_correlations_2D_histograms(
                x_data, y_data_list[i],
                x_data_labels, y_data_labels,
                data_labels[i], nbins, ranges, figsize=figsize,
            )
203
204
def plot_correlations_2D_histograms(
        x_data: dict,
        y_data: dict,
        x_data_labels: list,
        y_data_labels: list,
        data_label: str,
        nbins: int,
        ranges: tuple,
        figsize: tuple = (10.0, 7.5),
        ):
    """Plot a grid of 2D histograms showing correlations between all x/y variable pairs.

    Saves to ``{output_dir}/Correlations_2dhist_{data_label}.{file_format}``.

    Parameters
    ----------
    x_data : dict
        Maps variable objects to x arrays for a single dataset. Its keys
        define the grid columns.
    y_data : dict
        Maps variable objects to y arrays for a single dataset. Its keys
        define the grid rows.
    x_data_labels : list of str
        x-axis labels for each x variable.
    y_data_labels : list of str
        y-axis labels for each y variable.
    data_label : str
        Dataset label used in the figure title and file name.
    nbins : int
        Number of bins along the x axis of each 2D histogram. The y axis
        uses a fixed 10 bins.
    ranges : tuple of (dict, dict)
        Pre-computed ``(x_ranges, y_ranges)`` dicts mapping variable objects
        to ``[min, max]`` limits, as produced by :func:`plot_correlations`.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(10.0, 7.5)``.
    """
    x_vars = list(x_data.keys())
    y_vars = list(y_data.keys())

    # squeeze=False keeps axs two-dimensional even with a single x or y
    # variable, so the axs[row, col] indexing below is always valid.
    fig, axs = plt.subplots(len(y_vars), len(x_vars), sharex="col", sharey="row", squeeze=False)
    fig.set_size_inches(*figsize)
    fig.suptitle(f"Correlations ({data_label})")
    fig.subplots_adjust(wspace=0.2, hspace=0.2)

    last_row = len(y_vars) - 1
    for row, vary in enumerate(y_vars):
        for col, varx in enumerate(x_vars):
            ax = axs[row, col]
            x_limits = ranges[0][varx]
            y_limits = ranges[1][vary]

            # Apply the limits first: the histogram range below is read back
            # from the axes.
            ax.set_xlim(x_limits[0], x_limits[1])
            ax.set_ylim(y_limits[0], y_limits[1])

            # Axis labels only on the outer edge of the shared-axis grid.
            if row == last_row:
                ax.set_xlabel(x_data_labels[col], fontsize=10)
            if col == 0:
                ax.set_ylabel(y_data_labels[row], fontsize=10)

            with np.errstate(invalid='ignore'):
                ax.hist2d(x_data[varx], y_data[vary], bins=[nbins, 10],
                          range=[ax.get_xlim(), ax.get_ylim()], rasterized=True)

            # Re-apply the requested limits after drawing the histogram.
            ax.set_xlim(x_limits[0], x_limits[1])
            ax.set_ylim(y_limits[0], y_limits[1])

    fig.tight_layout()
    fig.savefig(f"{output_dir}/Correlations_2dhist_{data_label}.{file_format}", format=file_format)
    plt.close(fig)
265
266
def plot_2D_histogram(data: dict, label: str, bins: tuple, phi_var, tan_lambda_var):
    """Histogram how many tracks fall into each cell of the phi vs tan(lambda) plane.

    The values of both tracks (``name1`` and ``name2`` branches) are merged
    into one sample before histogramming. The histogram always covers
    tan(lambda) in [-2, 3] and phi in [-pi, pi]. The figure is written to
    ``{output_dir}/map_2dhist_{label}.{file_format}``.

    Parameters
    ----------
    data : dict
        Maps branch names to arrays.
    label : str
        Dataset label, used only in the output file name.
    bins : tuple of int
        ``(n_phi_bins, n_tan_lambda_bins)``.
    phi_var : TrackVariable
        Variable drawn on the y axis (phi).
    tan_lambda_var : TrackVariable
        Variable drawn on the x axis (tan lambda).
    """
    plt.figure()
    plt.title("Detector map histogram")

    # Merge the two tracks of every event into a single sample per coordinate.
    tan_lambda_values = np.concatenate((data[tan_lambda_var.name1], data[tan_lambda_var.name2]))
    phi_values = np.concatenate((data[phi_var.name1], data[phi_var.name2]))

    # hist2d expects (x bins, y bins), i.e. the reverse of the `bins` argument order.
    plt.hist2d(
        tan_lambda_values,
        phi_values,
        (bins[1], bins[0]),
        range=((-2, 3), (-np.pi, np.pi)),
        rasterized=True,
    )
    plt.colorbar(label="Events")
    plt.xlabel(tan_lambda_var.latex + tan_lambda_var.unit.name, fontsize=12)
    plt.ylabel(phi_var.latex + phi_var.unit.name, fontsize=12)

    destination = f"{output_dir}/map_2dhist_{label}.{file_format}"
    plt.savefig(fname=destination, format=f"{file_format}")
    plt.close()
300
301
def draw_map(
        map_type: str,
        data: dict,
        label: str,
        variable,
        observable_mode: str,
        bins: tuple,
        phi_var,
        tan_lambda_var,
        vmin: float = None,
        vmax: float = None,
        ):
    """Draw a median or resolution detector map for ``variable`` binned in phi vs tan(lambda).

    The map is built by iterating over phi strips and computing per-bin
    statistics along tan(lambda) within each strip using :func:`to_bins`. The
    caller explicitly chooses how track pairs are combined via
    ``observable_mode``:
      - ``'delta'``: difference (track1 - track2)
      - ``'sigma'``: sum (track1 + track2)

    In this codebase, ``sigma`` is intended for dimuon d0 observables, while
    cosmics maps typically use ``delta``.

    The plotted values (and any automatically derived colour limits) are the
    per-bin statistic divided by sqrt(2) and multiplied by
    ``variable.unit.convert``. The tan(lambda) axis always covers [-2, 3]
    and the phi axis [-pi, pi].

    Saves to ``{output_dir}/{map_type}_map_{variable.plaintext}_{label}.{file_format}``.

    Parameters
    ----------
    map_type : str
        ``'median'`` to map the per-bin median, or ``'resolution'`` to map
        the per-bin sigma68.
    data : dict
        Data dictionary mapping branch names to arrays.
    label : str
        Dataset label used in the figure title and file name.
    variable : TrackVariable
        Observable to map (e.g. ``d`` or ``z``).
    observable_mode : str
        Combination mode for track1/track2 values:
          - ``'delta'`` for difference (track1 - track2)
          - ``'sigma'`` for sum (track1 + track2)
    bins : tuple
        ``(phi_bins, tan_lambda_bins)`` for the map grid. Each element is
        either a bin count (Python or NumPy integer) or a sequence of bin
        edges.
    phi_var : TrackVariable
        Variable providing the phi coordinate (y axis of the map).
    tan_lambda_var : TrackVariable
        Variable providing the tan(lambda) coordinate (x axis of the map).
    vmin : float, optional
        Minimum of the colour scale. Auto-computed from data if not given.
    vmax : float, optional
        Maximum of the colour scale. Auto-computed from data if not given.

    Raises
    ------
    ValueError
        If ``map_type`` or ``observable_mode`` has an unsupported value. Both
        checks run before any figure is created, so no figure is left open.
    """
    titles = {"median": f"Median map ({label})", "resolution": f"Resolution map ({label})"}
    if map_type not in titles:
        raise ValueError("map_type must be 'median' or 'resolution'")
    if observable_mode not in {"delta", "sigma"}:
        raise ValueError("observable_mode must be 'delta' or 'sigma'")

    # A bin specification is either a count or a sequence of edges; NumPy
    # integers are accepted as counts as well.
    n_phi = bins[0] if isinstance(bins[0], (int, np.integer)) else len(bins[0]) - 1
    n_tan_lambda = bins[1] if isinstance(bins[1], (int, np.integer)) else len(bins[1]) - 1
    # One row per phi strip, one column per tan(lambda) bin.
    map_data = np.zeros((n_phi, n_tan_lambda))

    phi1, phi2 = data[phi_var.name1], data[phi_var.name2]
    tan_lambda1, tan_lambda2 = data[tan_lambda_var.name1], data[tan_lambda_var.name2]

    x_bins = np.histogram_bin_edges(np.concatenate((tan_lambda1, tan_lambda2)), bins[1], (-2, 3))
    y_bins = np.histogram_bin_edges(np.concatenate((phi1, phi2)), bins[0], (-np.pi, np.pi))

    # The pair observable does not depend on the strip, so combine once.
    if observable_mode == "sigma":
        combined = data[variable.name1] + data[variable.name2]
    else:
        combined = data[variable.name1] - data[variable.name2]

    # Position of the wanted statistic in the six-tuple returned by to_bins:
    # element 1 is used as the median, element 4 as the resolution.
    statistic_index = 1 if map_type == "median" else 4

    for i, (phi_low, phi_high) in enumerate(zip(y_bins[:-1], y_bins[1:])):
        # An event contributes once for each of its tracks lying in the strip.
        # NOTE(review): both strip edges are inclusive, so a track exactly on
        # an inner edge is counted in two neighbouring strips.
        tracks1 = np.logical_and(phi_low <= phi1, phi1 <= phi_high)
        tracks2 = np.logical_and(phi_low <= phi2, phi2 <= phi_high)
        values = np.concatenate((combined[tracks1], combined[tracks2]))
        positions = np.concatenate((tan_lambda1[tracks1], tan_lambda2[tracks2]))

        map_data[i] = to_bins(positions, values, bins[1], (-2, 3))[statistic_index]

    if vmin is None or vmax is None:
        # Derive the missing colour limit(s) from the filled (non-NaN) cells.
        rawmin, rawmax = auto_range([map_data[~np.isnan(map_data)]], 98)
        if vmin is None:
            vmin = rawmin / 2 ** 0.5 * variable.unit.convert
        if vmax is None:
            vmax = rawmax / 2 ** 0.5 * variable.unit.convert

    statistic = r"$\tilde{}$" if map_type == "median" else r"$\sigma_{68}$"
    combination = r"$\Sigma$" if observable_mode == "sigma" else r"$\Delta$"
    colorbar_label = f"{statistic}({combination}{variable.latex})/" + r"$\sqrt{2}$" + variable.unit.dname

    # The figure is created only now, after everything that can fail on bad
    # input has run.
    plt.figure()
    plt.title(titles[map_type])
    plt.pcolormesh(x_bins, y_bins, map_data / 2 ** 0.5 * variable.unit.convert, vmax=vmax, vmin=vmin, rasterized=True)
    plt.colorbar(label=colorbar_label)
    plt.ylabel(phi_var.latex + phi_var.unit.name, fontsize=12)
    plt.xlabel(tan_lambda_var.latex + tan_lambda_var.unit.name, fontsize=12)
    plt.savefig(f"{output_dir}/{map_type}_map_{variable.plaintext}_{label}.{file_format}", format=file_format)
    plt.close()
434
435
def plot_resolutions_hist(
        suptitle: str,
        data: dict,
        labels: dict,
        nbins: int,
        ranges=90,
        vars_to_fit: list = None,
        shape: tuple = (2, 3),
        figsize: tuple = (11.0, 8.0),
        ):
    """Plot a grid of residual histograms, optionally with Gaussian fits, for each variable.

    Saves to ``{output_dir}/{suptitle}.{file_format}`` (spaces replaced by underscores).

    Parameters
    ----------
    suptitle : str
        Figure title, also used to derive the output file name.
    data : dict
        Maps variable objects to 1-D arrays of residual values.
    labels : dict
        Maps variable objects to x-axis label strings. If a label contains
        bracketed text such as ``[cm]``, the content of the last bracket is
        appended as the unit in the fit annotation.
    nbins : int
        Number of histogram bins.
    ranges : int or float or dict, optional
        If a number, the central ``ranges``% of each variable's data is used
        as the histogram range (symmetric, via :func:`auto_range`). If a dict,
        maps variable objects to explicit ``(min, max)`` tuples. Default is 90.
    vars_to_fit : list, optional
        Subset of variables for which a Gaussian fit is overlaid and
        annotated. Default is None, meaning no fits.
    shape : tuple of int, optional
        ``(nrows, ncols)`` layout of the subplot grid. Default is ``(2, 3)``.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(11.0, 8.0)``.
    """
    fit_vars = [] if vars_to_fit is None else vars_to_fit
    nrows, ncols = shape[0], shape[1]

    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axs = plt.subplots(nrows, ncols, squeeze=False)
    fig.suptitle(suptitle, y=0.98)
    fig.set_size_inches(*figsize)
    fig.subplots_adjust(wspace=0.4, hspace=0.6, top=0.85)

    for i, var in enumerate(data.keys()):
        # Fill the grid row by row; a single-row grid wraps around its columns.
        ax = axs[i // ncols if nrows >= 2 else 0, i % ncols]
        ax.set_xlabel(labels[var])

        if isinstance(ranges, (int, float)):
            bounds = auto_range([data[var]], ranges, modify=0.1, symmetric=True)
        else:
            bounds = ranges[var]

        counts, edges, _ = ax.hist(data[var], nbins, range=bounds, rasterized=True)
        ax.set_ylabel("Entries")

        if var not in fit_vars:
            continue

        # Each count belongs to the centre of its bin. Fitting against
        # nbins points spread from edge to edge would misplace every point
        # and bias the fitted width.
        centers = 0.5 * (edges[:-1] + edges[1:])
        try:
            fit, cov = curve_fit(normal_distribution, centers, counts, (5000, 0, 1))
        except (RuntimeError, ValueError, TypeError):
            # Best effort: a failed fit must not prevent the histograms from being saved.
            print(f"[Warning] Failed to fit {var.plaintext}")
            continue

        err = np.sqrt(np.diag(cov))
        curve_x = np.linspace(bounds[0], bounds[1], 200)
        ax.plot(curve_x, normal_distribution(curve_x, *fit), "k")

        # Unit taken from the last [...] in the axis label; empty when the
        # label carries no bracketed unit.
        brackets = re.findall(r'\[(.*?)\]', labels[var])
        used_units = brackets[-1] if brackets else ""
        fit_parameters = (
            f"a = {fit[0]:.3}" + r" $\pm$ " + f"{err[0]:.1} " + "\n" +
            fr"$\mu$ = {fit[1]:.3}" + r" $\pm$ " + f"{err[1]:.1} " + used_units + "\n" +
            fr"$\sigma$ = {fit[2]:.3}" + r" $\pm$ " + f"{err[2]:.1} " + used_units
        )
        # Anchor the annotation just above the top-left corner of the panel.
        ax.text(ax.get_xlim()[0], ax.get_ylim()[1], fit_parameters, size=11, va='bottom')

    fig.savefig(f"{output_dir}/{suptitle.replace(' ', '_')}.{file_format}", format=file_format)
    plt.close(fig)
516
517
def plot_resolution_comparison(
        suptitle: str,
        data_list: list,
        data_labels: list,
        labels: dict,
        nbins: float,
        ranges=90,
        shape: tuple = (2, 3),
        figsize: tuple = (11.0, 8.0),
        ):
    """Compare residual distributions of several datasets, one panel per variable.

    In every panel the histograms of all datasets are overlaid. Once the last
    dataset has been drawn, the panel gets a legend that lists the median and
    sigma68 (half the distance between the 16th and 84th percentiles) of each
    dataset. A figure-level legend maps colours to ``data_labels``. The
    figure is written to ``{output_dir}/{suptitle}.{file_format}`` with the
    spaces of ``suptitle`` replaced by underscores.

    Parameters
    ----------
    suptitle : str
        Figure title; also the base of the output file name.
    data_list : list of dict
        One dict per dataset, mapping variable objects to residual arrays.
    data_labels : list of str
        Legend label of each dataset.
    labels : dict
        Maps variable objects to the x-axis label of their panel.
    nbins : int
        Number of histogram bins.
    ranges : int or float or dict, optional
        A number is passed to :func:`auto_range` (symmetric mode) as the
        percentage of the data to cover; a dict maps variable objects to
        explicit ``(min, max)`` limits. Default is 90.
    shape : tuple of int, optional
        ``(nrows, ncols)`` of the panel grid. Default is ``(2, 3)``.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(11.0, 8.0)``.
    """
    n_rows, n_cols = shape[0], shape[1]
    # With squeeze=False the axes always come back as a 2-D array.
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
    fig.suptitle(suptitle, y=0.98)
    fig.subplots_adjust(wspace=0.4, hspace=0.6, top=0.85)
    fig.set_size_inches(*figsize)

    # Per dataset: variable -> "Median / sigma68" summary text for the panel legend.
    fits = [{} for _ in data_list]
    final_dataset = len(data_list) - 1
    automatic_limits = isinstance(ranges, (int, float))

    for dataset_index, dataset in enumerate(data_list):
        for panel, var in enumerate(dataset.keys()):
            # Panels are filled row by row; a single-row grid wraps around its columns.
            row = panel // n_cols if n_rows >= 2 else 0
            ax = axs[row, panel % n_cols]

            ax.set_xlabel(labels[var])

            # NOTE(review): automatic limits are derived from each dataset on
            # its own, so overlaid histograms may use different binning -
            # confirm this is intended.
            if automatic_limits:
                bounds = auto_range([dataset[var]], ranges, modify=0.1, symmetric=True)
            else:
                bounds = ranges[var]

            ax.hist(dataset[var], nbins, range=bounds, alpha=0.6, label=data_labels[dataset_index], rasterized=True)
            ax.set_ylabel("Entries")

            upper = np.percentile(dataset[var], 84)
            lower = np.percentile(dataset[var], 16)
            sig68 = (upper - lower) / 2
            median = np.median(dataset[var])
            summary = fr"Median = {median:.3}" + var.unit.dname + "\n"
            summary += fr"$\sigma_{{68}}$ = {sig68:.3}" + var.unit.dname
            fits[dataset_index][var] = summary

            # The panel legend needs the summaries of all datasets, so it is
            # only built while drawing the last one.
            if dataset_index == final_dataset:
                handles, _ = ax.get_legend_handles_labels()
                ax.legend(handles=handles, labels=[entry[var] for entry in fits], framealpha=0, loc='upper left')

    fig.legend(data_labels, loc="upper center", ncol=len(data_labels), bbox_to_anchor=(0.5, 0.95))
    destination = f"{output_dir}/{suptitle.replace(' ', '_')}.{file_format}"
    plt.savefig(fname=destination, format=f"{file_format}")
    plt.close()
593
594
def plot_resolution(
        suptitle: str,
        datadict: dict,
        data_labels: list,
        axlabels: dict,
        xlimit: list,
        ylimits: dict,
        bins,
        fitfunction: callable = None,
        fitlabel: callable = None,
        fitrange: list = None,
        figsize: tuple = (12.0, 7.0),
        err_override: dict = None,
        ):
    """Plot sigma68 vs an x variable (e.g. pseudomomentum) for each observable.

    One subplot is produced per variable in ``datadict``. An optional
    parametric fit is overlaid on each dataset's sigma68 profile. Saves to
    ``{output_dir}/{suptitle}{var_names}.{file_format}``.

    Parameters
    ----------
    suptitle : str
        Figure title, also used (with variable names appended) to derive the
        output file name.
    datadict : dict
        Maps variable objects to a list of ``[xdata, ydata]`` pairs, one per
        dataset. ``xdata`` and ``ydata`` are 1-D arrays of equal length.
    data_labels : list of str
        Legend labels, one per dataset.
    axlabels : dict
        Maps variable objects to ``[xlabel, ylabel]`` string pairs.
    xlimit : list of float
        ``[xmin, xmax]`` range for the x axis.
    ylimits : dict
        Maps variable objects to ``[ymin, ymax]`` range for the y axis.
    bins : int or sequence of float
        Bin edges or number of bins passed to :func:`to_bins`.
    fitfunction : callable, optional
        Function ``f(x, *params)`` to fit to the sigma68 profile; any number
        of parameters is supported. If None, no fit is drawn. Default is None.
    fitlabel : callable, optional
        Function ``f(params, errors) -> str`` that produces the fit annotation
        string. Required when ``fitfunction`` is provided. Its first line is
        shown only for the first dataset of each subplot. Default is None.
    fitrange : list of float, optional
        ``[xmin, xmax]`` sub-range used for the fit. Defaults to ``xlimit``.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(12.0, 7.0)``.
    err_override : dict, optional
        Maps variable objects to a list of per-dataset sigma68 uncertainty
        arrays that replace the values computed by :func:`to_bins`. Useful for
        providing external uncertainty estimates. Default is None.

    Raises
    ------
    TypeError
        If ``fitfunction`` is given without ``fitlabel``. Raised before any
        figure is created.
    """
    do_fit = bool(fitfunction)
    if do_fit:
        if fitlabel is None:
            raise TypeError("fitlabel is required when fitfunction is provided")
        fit_window = fitrange if fitrange else xlimit
        curve_x = np.linspace(fit_window[0], fit_window[1], 100)

    # squeeze=False keeps axs two-dimensional even for a single variable.
    fig, axs = plt.subplots(1, len(datadict), squeeze=False)
    fig.set_size_inches(*figsize)
    fig.subplots_adjust(wspace=0.3)
    fig.suptitle(suptitle)

    for i, var in enumerate(datadict.keys()):
        ax = axs[0, i]

        # Restart the colour cycle in every subplot so that a dataset keeps
        # the same colour across all variables.
        if len(data_labels) <= 3:
            color = iter(["b", "r", "g"])
        else:
            color = iter(plt.cm.rainbow(np.linspace(0, 1, len(data_labels))))

        for j, dataset in enumerate(datadict[var]):
            c = next(color)
            xdata, ydata = dataset[0], dataset[1]
            x_vals, _, x_width, _, sig68, sig68_uncert = to_bins(xdata, ydata, bins, xlimit)
            if err_override:
                sig68_uncert = err_override[var][j]
            ax.errorbar(x_vals, sig68, sig68_uncert, x_width, fmt="o", label=data_labels[j], c=c, rasterized=True)

            if not do_fit:
                continue

            fit_x = np.asarray(x_vals)
            fit_y = np.asarray(sig68)
            fit_sigma = np.asarray(sig68_uncert)
            # Keep points strictly inside the fit window and drop bins whose
            # value or uncertainty is not finite; curve_fit rejects NaN/inf input.
            fitmask = (fit_window[0] < fit_x) & (fit_x < fit_window[1])
            fitmask &= np.isfinite(fit_y) & np.isfinite(fit_sigma)

            params, covariance = curve_fit(
                fitfunction,
                fit_x[fitmask],
                fit_y[fitmask],
                sigma=fit_sigma[fitmask],
            )
            errors = np.sqrt(np.diag(covariance))

            # Only the first dataset keeps the first line of the annotation.
            fittextlines = fitlabel(params, errors).splitlines()
            fittext = "\n".join(fittextlines) if j == 0 else "\n".join(fittextlines[1:])
            # Unpack all fitted parameters so models with any parameter count can be drawn.
            ax.plot(curve_x, fitfunction(curve_x, *params), label=fittext, c=c)

        ax.set_xlim(xlimit[0], xlimit[1])
        ax.set_ylim(ylimits[var][0], ylimits[var][1])
        ax.set_xlabel(axlabels[var][0])
        ax.set_ylabel(axlabels[var][1])
        ax.legend()

    # File name suffix of the form "[name1,name2]" built from the variables' plain-text names.
    varsplaintext = f"{[var.plaintext for var in datadict.keys()]}".replace("'", "").replace(" ", "")
    fig.savefig(f"{output_dir}/{suptitle.replace(' ', '_')}{varsplaintext}.{file_format}", format=file_format)
    plt.close(fig)
699