Belle II Software development
plotting.py
1
8"""
9Plotting functions for alignment validation.
10
11Module-level config (set by caller before running):
12 output_dir — directory where plots are saved (default: "validation_plots")
13 file_format — image format (default: "pdf")
14"""
15
16import re
17
18import matplotlib.pyplot as plt
19import numpy as np
20from scipy.optimize import curve_fit
21
22from alignment_validation.utils import auto_range, normal_distribution, to_bins
23
24
# Directory that receives every figure saved by this module. Callers may
# reassign this module attribute before invoking any plotting function.
output_dir = "validation_plots"

# Format / file extension used for every saved figure (passed to savefig).
file_format = "pdf"

# Module-wide typography defaults, written into matplotlib's global rcParams
# as a side effect of importing this module.
plt.rcParams.update(
    {
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
        "legend.fontsize": 12,
        "figure.titlesize": 16,
    }
)
38
39
def plot_histogram(
    data_list: list,
    data_labels: list,
    plot_filename: str,
    xlabel: str,
    range: int = 96,
    opacity: float = 0.5,
):
    """Draw overlaid 30-bin histograms, one per dataset, and save the figure.

    A numeric ``range`` is treated as a coverage percentage and converted into
    one shared set of limits for all datasets via :func:`auto_range`; any
    other value is handed to ``plt.hist`` as explicit limits.

    The figure is written to ``{output_dir}/{plot_filename}.{file_format}``.

    Parameters
    ----------
    data_list : list of array-like
        Samples to histogram, one entry per dataset.
    data_labels : list of str
        Legend entry for each dataset, indexed in parallel with ``data_list``.
    plot_filename : str
        Base name of the output file, without extension.
    xlabel : str
        Label for the x axis (LaTeX allowed).
    range : int or float or tuple, optional
        Coverage percentage (default 96) or an explicit ``(min, max)`` pair.
        Note that this parameter shadows the builtin ``range`` inside the
        function; the name is kept for caller compatibility.
    opacity : float, optional
        Alpha applied to every histogram. Default is 0.5.
    """
    use_percentage = isinstance(range, (int, float))
    hist_ranges = auto_range(data_list, range, 0.1) if use_percentage else range

    plt.figure(figsize=(5, 4))
    plt.ticklabel_format(scilimits=(-2, 3), useMathText=True)
    plt.ylabel("Entries")
    plt.xlabel(xlabel)

    for idx, sample in enumerate(data_list):
        plt.hist(
            sample,
            bins=30,
            range=hist_ranges,
            alpha=opacity,
            density=False,
            label=data_labels[idx],
            rasterized=True,
        )

    plt.legend()
    plt.savefig(f"{output_dir}/{plot_filename}.{file_format}", bbox_inches="tight")
    plt.close()
86
87
def plot_correlations(
    plot_type: str,
    x_data_list: list,
    y_data_list: list,
    x_data_labels: list,
    y_data_labels: list,
    data_labels: list,
    nbins: int,
    ranges: tuple = (96, 90),
    make_2D_hist: bool = True,
    figsize: tuple = (10.0, 7.5),
):
    """Plot a grid of per-bin median or resolution profiles for all x/y variable pairs.

    Produces one subplot per (y variable, x variable) combination: y variables
    run along the rows, x variables along the columns, and every dataset is
    overlaid as an errorbar series. Optionally also produces matching 2D
    histogram plots via :func:`plot_correlations_2D_histograms`.

    Saves to ``{output_dir}/{plot_type}_correlations.{file_format}``.

    Parameters
    ----------
    plot_type : str
        ``'median'`` to plot per-bin medians, or ``'resolution'`` to plot
        per-bin sigma68 values.
    x_data_list : list of dict
        One dict per dataset, mapping variable objects to x arrays.
    y_data_list : list of dict
        One dict per dataset, mapping variable objects to y arrays. Must have
        the same keys across all datasets.
    x_data_labels : list of str
        x-axis labels for each x variable, in the same order as the dict keys.
    y_data_labels : list of str
        y-axis labels for each y variable, in the same order as the dict keys.
    data_labels : list of str
        Legend labels, one per dataset.
    nbins : int
        Number of bins for the profile.
    ranges : tuple, optional
        ``(x_range, y_range)``. Each element is either a percentage (int or
        float) or a pre-computed dict mapping variable objects to
        ``[min, max]`` limits. For a percentage, x limits are the union of the
        per-dataset :func:`auto_range` results; y limits are the union of the
        symmetric :func:`auto_range` results (``'median'``) or
        ``[0, width of the central percentage interval]`` (``'resolution'``).
        Default is ``(96, 90)``.
    make_2D_hist : bool, optional
        If True, also call :func:`plot_correlations_2D_histograms` for each
        dataset. Default is True.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(10.0, 7.5)``.

    Raises
    ------
    ValueError
        If ``plot_type`` is neither ``'median'`` nor ``'resolution'``.
    """
    titles = {"median": "Correlations", "resolution": "Resolution relations"}
    # Validate before creating the figure so a bad argument does not leave an
    # open, never-closed figure behind.
    if plot_type not in titles:
        raise ValueError("plot_type must be 'median' or 'resolution'")

    x_vars = list(x_data_list[0].keys())
    y_vars = list(y_data_list[0].keys())
    x_ranges, y_ranges = ranges[0], ranges[1]

    if isinstance(x_ranges, (int, float)):
        percentage = x_ranges
        x_ranges = {}
        for var in x_vars:
            limits = [auto_range([x_data[var]], percentage, modify=0.1) for x_data in x_data_list]
            # Union of the per-dataset ranges so every dataset fits on the axes.
            x_ranges[var] = [min(lim[0] for lim in limits), max(lim[1] for lim in limits)]

    if isinstance(y_ranges, (int, float)):
        percentage = y_ranges
        y_ranges = {}
        for var in y_vars:
            if plot_type == "median":
                limits = [
                    auto_range([y_data[var]], percentage, modify=0.1, symmetric=True)
                    for y_data in y_data_list
                ]
            else:
                # A sigma68 profile is non-negative; its upper limit is the
                # width of the central ``percentage``% interval of y.
                limits = [
                    (0,
                     np.percentile(y_data[var], 50 + percentage / 2) -
                     np.percentile(y_data[var], 50 - percentage / 2))
                    for y_data in y_data_list
                ]
            y_ranges[var] = [min(lim[0] for lim in limits), max(lim[1] for lim in limits)]

    # squeeze=False keeps ``axs`` two-dimensional, so ``axs[row, col]`` also
    # works when there is only a single x or a single y variable.
    fig, axs = plt.subplots(len(y_vars), len(x_vars), sharex="col", sharey="row", squeeze=False)
    fig.set_size_inches(*figsize)
    fig.suptitle(titles[plot_type])
    fig.subplots_adjust(wspace=0.2, hspace=0.2)

    last_row = len(y_vars) - 1
    for row, vary in enumerate(y_vars):
        for col, varx in enumerate(x_vars):
            ax = axs[row, col]
            if row == last_row:
                ax.set_xlabel(x_data_labels[col], fontsize=10)
            if col == 0:
                ax.set_ylabel(y_data_labels[row], fontsize=10)
            ax.tick_params(labelsize=9)

            for i, _ in enumerate(x_data_list):
                binned = to_bins(x_data_list[i][varx], y_data_list[i][vary], nbins, x_ranges[varx])
                # to_bins returns a 6-tuple; positions 1/3 are used as the
                # per-bin median and its error, positions 4/5 as sigma68 and
                # its error.
                if plot_type == "median":
                    x_vals, y_vals, x_width, y_err, _, _ = binned
                else:
                    x_vals, _, x_width, _, y_vals, y_err = binned

                with np.errstate(invalid="ignore"):
                    ax.errorbar(x_vals, y_vals, y_err, x_width, fmt=".", label=data_labels[i])

            ax.set_xlim(x_ranges[varx][0], x_ranges[varx][1])
            ax.set_ylim(y_ranges[vary][0], y_ranges[vary][1])

    fig.legend(data_labels, loc="upper center", ncol=len(data_labels), bbox_to_anchor=(0.5, 0.95))
    plt.savefig(format=f"{file_format}", fname=f"{output_dir}/{plot_type}_correlations.{file_format}")
    plt.close()

    if make_2D_hist:
        for i, _ in enumerate(x_data_list):
            plot_correlations_2D_histograms(
                x_data_list[i], y_data_list[i],
                x_data_labels, y_data_labels,
                data_labels[i], nbins, (x_ranges, y_ranges), figsize=figsize,
            )
203
204
def plot_correlations_2D_histograms(
    x_data: dict,
    y_data: dict,
    x_data_labels: list,
    y_data_labels: list,
    data_label: str,
    nbins: int,
    ranges: tuple,
    figsize: tuple = (10.0, 7.5),
    y_bins: int = 10,
):
    """Plot a grid of 2D histograms showing correlations between all x/y variable pairs.

    y variables run along the rows and x variables along the columns. Saves to
    ``{output_dir}/Correlations_2dhist_{data_label}.{file_format}``.

    Parameters
    ----------
    x_data : dict
        Maps variable objects to x arrays for a single dataset.
    y_data : dict
        Maps variable objects to y arrays for a single dataset.
    x_data_labels : list of str
        x-axis labels for each x variable.
    y_data_labels : list of str
        y-axis labels for each y variable.
    data_label : str
        Dataset label used in the figure title and file name.
    nbins : int
        Number of bins along the x axis of each 2D histogram.
    ranges : tuple of (dict, dict)
        Pre-computed ``(x_ranges, y_ranges)`` dicts mapping variable objects
        to ``[min, max]`` limits, as produced by :func:`plot_correlations`.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(10.0, 7.5)``.
    y_bins : int, optional
        Number of bins along the y axis of each 2D histogram. Default is 10.
    """
    # squeeze=False keeps ``axs`` two-dimensional even for a single row or
    # column, so ``axs[row, col]`` always works.
    fig, axs = plt.subplots(len(y_data), len(x_data), sharex="col", sharey="row", squeeze=False)
    fig.set_size_inches(*figsize)
    fig.suptitle(f"Correlations ({data_label})")
    fig.subplots_adjust(wspace=0.2, hspace=0.2)

    x_ranges, y_ranges = ranges[0], ranges[1]
    last_row = len(y_data) - 1
    for row, vary in enumerate(y_data.keys()):
        for col, varx in enumerate(x_data.keys()):
            ax = axs[row, col]
            # Set the limits first: the histogram range is read back from them.
            ax.set_xlim(x_ranges[varx])
            ax.set_ylim(y_ranges[vary])

            if row == last_row:
                ax.set_xlabel(x_data_labels[col], fontsize=10)
            if col == 0:
                ax.set_ylabel(y_data_labels[row], fontsize=10)

            with np.errstate(invalid="ignore"):
                ax.hist2d(x_data[varx], y_data[vary], [nbins, y_bins],
                          [ax.get_xlim(), ax.get_ylim()], rasterized=True)

            # Re-apply the requested limits after drawing.
            ax.set_xlim(x_ranges[varx][0], x_ranges[varx][1])
            ax.set_ylim(y_ranges[vary][0], y_ranges[vary][1])

    fig.tight_layout()
    plt.savefig(format=f"{file_format}", fname=f"{output_dir}/Correlations_2dhist_{data_label}.{file_format}")
    plt.close()
265
266
def plot_2D_histogram(data: dict, label: str, bins: tuple, phi_var, tan_lambda_var):
    """Histogram the pooled (tan lambda, phi) positions of both tracks and save it.

    Values from track 1 and track 2 (``name1`` / ``name2`` branches) are
    concatenated into one sample. tan(lambda) goes on the x axis over
    ``[-2, 3]`` and phi on the y axis over ``[-pi, pi]``. Saves to
    ``{output_dir}/map_2dhist_{label}.{file_format}``.

    Parameters
    ----------
    data : dict
        Data dictionary mapping branch names to arrays.
    label : str
        Dataset label used in the file name.
    bins : tuple of int
        ``(n_phi_bins, n_tan_lambda_bins)``, i.e. y-axis count first.
    phi_var : TrackVariable
        Variable plotted on the y axis (phi).
    tan_lambda_var : TrackVariable
        Variable plotted on the x axis (tan lambda).
    """
    plt.figure()
    plt.title("Detector map histogram")

    tan_lambda_all = np.concatenate((data[tan_lambda_var.name1], data[tan_lambda_var.name2]))
    phi_all = np.concatenate((data[phi_var.name1], data[phi_var.name2]))
    # ``bins`` is ordered (phi, tan lambda) but hist2d wants (x, y).
    n_phi, n_tan_lambda = bins[0], bins[1]

    plt.hist2d(
        tan_lambda_all,
        phi_all,
        bins=(n_tan_lambda, n_phi),
        range=((-2, 3), (-np.pi, np.pi)),
        rasterized=True,
    )
    # NOTE(review): entries are tracks (two per event), yet the colour bar
    # is labelled "Events" -- confirm the intended wording.
    plt.colorbar(label="Events")
    plt.ylabel(phi_var.latex + phi_var.unit.name, fontsize=12)
    plt.xlabel(tan_lambda_var.latex + tan_lambda_var.unit.name, fontsize=12)
    plt.savefig(format=f"{file_format}", fname=f"{output_dir}/map_2dhist_{label}.{file_format}")
    plt.close()
300
301
def draw_map(
    map_type: str,
    data: dict,
    label: str,
    variable,
    observable_mode: str,
    bins: tuple,
    phi_var,
    tan_lambda_var,
    vmin: float = None,
    vmax: float = None,
):
    """Draw a median or resolution detector map for ``variable`` binned in phi vs tan(lambda).

    The map is built strip by strip in phi. For each strip, the tracks whose
    phi falls inside it are selected, and per-bin statistics along tan(lambda)
    are computed with :func:`to_bins`. Track pairs are combined according to
    ``observable_mode``:

    - ``'delta'``: difference (track1 - track2)
    - ``'sigma'``: sum (track1 + track2)

    In this codebase, ``sigma`` is intended for dimuon d0 observables, while
    cosmics maps typically use ``delta``.

    Phi strips are half-open, ``[lo, hi)``, except the last one, which is
    closed. This is the ``numpy.histogram`` convention, so a track lying
    exactly on an interior edge is counted in only one strip.

    Displayed values (and auto-computed colour limits) are divided by sqrt(2)
    and multiplied by ``variable.unit.convert``.

    Saves to ``{output_dir}/{map_type}_map_{variable.plaintext}_{label}.{file_format}``.

    Parameters
    ----------
    map_type : str
        ``'median'`` to map the per-bin median, or ``'resolution'`` to map
        the per-bin sigma68.
    data : dict
        Data dictionary mapping branch names to arrays.
    label : str
        Dataset label used in the figure title and file name.
    variable : TrackVariable
        Observable to map (e.g. ``d`` or ``z``).
    observable_mode : str
        Combination mode for track1/track2 values:
        ``'delta'`` for difference, ``'sigma'`` for sum.
    bins : tuple
        ``(phi_bins, tan_lambda_bins)`` for the map grid. Each entry is either
        a bin count (Python or NumPy integer) or a sequence of bin edges.
    phi_var : TrackVariable
        Variable providing the phi coordinate (y axis of the map).
    tan_lambda_var : TrackVariable
        Variable providing the tan(lambda) coordinate (x axis of the map).
    vmin : float, optional
        Minimum of the colour scale. Auto-computed from data if not given.
    vmax : float, optional
        Maximum of the colour scale. Auto-computed from data if not given.

    Raises
    ------
    ValueError
        If ``map_type`` or ``observable_mode`` is not one of the allowed values.
    """
    # Validate before opening a figure so bad arguments do not leak one.
    if map_type not in ("median", "resolution"):
        raise ValueError("map_type must be 'median' or 'resolution'")
    if observable_mode not in {"delta", "sigma"}:
        raise ValueError("observable_mode must be 'delta' or 'sigma'")

    # A scalar bin spec is a count; otherwise it is a sequence of edges.
    # np.ndim also accepts NumPy integer scalars, unlike isinstance(..., int).
    n_phi = bins[0] if np.ndim(bins[0]) == 0 else len(bins[0]) - 1
    n_tan_lambda = bins[1] if np.ndim(bins[1]) == 0 else len(bins[1]) - 1
    map_data = np.zeros((n_phi, n_tan_lambda))

    phi1, phi2 = data[phi_var.name1], data[phi_var.name2]
    tanl1, tanl2 = data[tan_lambda_var.name1], data[tan_lambda_var.name2]
    obs1, obs2 = data[variable.name1], data[variable.name2]

    x_bins = np.histogram_bin_edges(np.concatenate((tanl1, tanl2)), bins[1], (-2, 3))
    y_bins = np.histogram_bin_edges(np.concatenate((phi1, phi2)), bins[0], (-np.pi, np.pi))

    # Per-event combination of the two tracks; each event then contributes
    # once for each of its tracks that falls inside a given phi strip.
    combined = obs1 + obs2 if observable_mode == "sigma" else obs1 - obs2

    n_strips = len(y_bins) - 1
    for i in range(n_strips):
        lo, hi = y_bins[i], y_bins[i + 1]
        below_upper = np.less_equal if i == n_strips - 1 else np.less
        tracks1 = np.logical_and(lo <= phi1, below_upper(phi1, hi))
        tracks2 = np.logical_and(lo <= phi2, below_upper(phi2, hi))

        strip_tan_lambda = np.concatenate((tanl1[tracks1], tanl2[tracks2]))
        strip_values = np.concatenate((combined[tracks1], combined[tracks2]))
        binned = to_bins(strip_tan_lambda, strip_values, bins[1], (-2, 3))
        # Position 1 of to_bins' result is the per-bin median, position 4
        # the per-bin sigma68 (same convention as the profile plots).
        map_data[i] = binned[1] if map_type == "median" else binned[4]

    finite_values = map_data[~np.isnan(map_data)].flatten()
    if vmin is None or vmax is None:
        rawmin, rawmax = auto_range([finite_values], 98)
        if vmin is None:
            vmin = rawmin / 2 ** 0.5 * variable.unit.convert
        if vmax is None:
            vmax = rawmax / 2 ** 0.5 * variable.unit.convert

    plt.figure()
    if map_type == "median":
        plt.title(f"Median map ({label})")
    else:
        plt.title(f"Resolution map ({label})")

    plt.pcolormesh(x_bins, y_bins, map_data / 2 ** 0.5 * variable.unit.convert, vmax=vmax, vmin=vmin, rasterized=True)

    stat_symbol = "$\\tilde{}$" if map_type == "median" else "$\\sigma_{68}$"
    combo_symbol = "$\\Sigma$" if observable_mode == "sigma" else "$\\Delta$"
    plt.colorbar(label=f"{stat_symbol}({combo_symbol}{variable.latex})/$\\sqrt{{2}}$" + variable.unit.dname)

    plt.ylabel(phi_var.latex + phi_var.unit.name, fontsize=12)
    plt.xlabel(tan_lambda_var.latex + tan_lambda_var.unit.name, fontsize=12)
    plt.savefig(format=f"{file_format}",
                fname=f"{output_dir}/{map_type}_map_{variable.plaintext}_{label}.{file_format}")
    plt.close()
434
435
def plot_resolutions_hist(
    suptitle: str,
    data: dict,
    labels: dict,
    nbins: int,
    ranges=90,
    vars_to_fit: list = None,
    shape: tuple = (2, 3),
    figsize: tuple = (11.0, 8.0),
):
    """Plot a grid of residual histograms, optionally with Gaussian fits, for each variable.

    Saves to ``{output_dir}/{suptitle}.{file_format}`` (spaces replaced by underscores).

    Parameters
    ----------
    suptitle : str
        Figure title, also used to derive the output file name.
    data : dict
        Maps variable objects to 1-D arrays of residual values.
    labels : dict
        Maps variable objects to x-axis label strings. A trailing ``[unit]``
        in the label is reused as the unit in the fit annotation.
    nbins : int
        Number of histogram bins.
    ranges : int or float or dict, optional
        If a number, the central ``ranges``% of each variable's data is used
        as the histogram range (symmetric, via :func:`auto_range`). If a dict,
        maps variable objects to explicit ``(min, max)`` tuples. Default is 90.
    vars_to_fit : list, optional
        Subset of variables for which a Gaussian (:func:`normal_distribution`)
        is fitted to the bin counts at the bin centres, overlaid and
        annotated. Default is None (no fits).
    shape : tuple of int, optional
        ``(nrows, ncols)`` layout of the subplot grid. Default is ``(2, 3)``.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(11.0, 8.0)``.
    """
    # Avoid a shared mutable default; None means "fit nothing".
    if vars_to_fit is None:
        vars_to_fit = []

    fig, axs = plt.subplots(shape[0], shape[1])
    fig.suptitle(suptitle, y=0.98)
    fig.set_size_inches(*figsize)
    fig.subplots_adjust(wspace=0.4, hspace=0.6, top=0.85)

    nrows, ncols = shape[0], shape[1]
    for i, var in enumerate(data.keys()):
        # plt.subplots returns a 2-D array, a 1-D array or a single Axes
        # depending on the grid shape.
        if nrows >= 2 and ncols >= 2:
            ax = axs[i // ncols, i % ncols]
        elif ncols >= 2:
            ax = axs[i % ncols]
        elif nrows >= 2:
            ax = axs[i // ncols]
        else:
            ax = axs

        ax.set_xlabel(labels[var])

        if isinstance(ranges, (int, float)):
            bounds = auto_range([data[var]], ranges, modify=0.1, symmetric=True)
        else:
            bounds = ranges[var]

        counts, edges, _ = ax.hist(data[var], nbins, range=bounds, rasterized=True)
        ax.set_ylabel("Entries")

        if var not in vars_to_fit:
            continue

        # Fit at the bin centres: each count belongs to its bin centre. Using
        # nbins evenly spaced points spanning [lo, hi] instead would stretch
        # the x scale by nbins / (nbins - 1) and bias mu and sigma.
        centres = 0.5 * (edges[:-1] + edges[1:])
        try:
            fit, cov = curve_fit(normal_distribution, centres, counts, (5000, 0, 1))
        except Exception:
            # Best effort: a failed fit only loses the overlay, not the plot.
            print(f"[Warning] Failed to fit {var.plaintext}")
            continue
        err = np.sqrt(np.diag(cov))

        curve_x = np.linspace(bounds[0], bounds[1], 200)
        ax.plot(curve_x, normal_distribution(curve_x, *fit), "k")

        brackets = re.findall(r'\[(.*?)\]', labels[var])
        # Labels without a "[unit]" part get no unit suffix (instead of a
        # None that cannot be concatenated).
        used_units = brackets[-1] if brackets else ""
        fit_parameters = (
            f"a = {fit[0]:.3}" + r" $\pm$ " + f"{err[0]:.1} " + "\n" +
            fr"$\mu$ = {fit[1]:.3}" + r" $\pm$ " + f"{err[1]:.1} " + used_units + "\n" +
            fr"$\sigma$ = {fit[2]:.3}" + r" $\pm$ " + f"{err[2]:.1} " + used_units
        )
        ax.text(ax.get_xlim()[0], ax.get_ylim()[1], fit_parameters, size=11, va='bottom')

    plt.savefig(format=f"{file_format}", fname=f"{output_dir}/{suptitle.replace(' ', '_')}.{file_format}")
    plt.close()
517
518
def plot_resolution_comparison(
    suptitle: str,
    data_list: list,
    data_labels: list,
    labels: dict,
    nbins: int,
    ranges=90,
    shape: tuple = (2, 3),
    figsize: tuple = (11.0, 8.0),
):
    """Overlay residual distributions from multiple datasets and annotate with median and sigma68.

    Each subplot shows histograms from all datasets overlaid with identical
    bin edges. The median and sigma68 (half the 16th-84th percentile width,
    computed on the full sample) of each dataset are listed in the subplot
    legend. Saves to ``{output_dir}/{suptitle}.{file_format}`` (spaces
    replaced by underscores).

    Parameters
    ----------
    suptitle : str
        Figure title, also used to derive the output file name.
    data_list : list of dict
        One dict per dataset, mapping variable objects to residual arrays.
    data_labels : list of str
        Legend labels, one per dataset.
    labels : dict
        Maps variable objects to x-axis label strings.
    nbins : int
        Number of histogram bins.
    ranges : int or float or dict, optional
        If a number, the central ``ranges``% of each dataset's data is
        computed with :func:`auto_range` (symmetric), and the union over all
        datasets is used as the common histogram range for that variable. If
        a dict, maps variable objects to explicit ``(min, max)`` tuples.
        Default is 90.
    shape : tuple of int, optional
        ``(nrows, ncols)`` layout of the subplot grid. Default is ``(2, 3)``.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(11.0, 8.0)``.
    """
    fig, axs = plt.subplots(shape[0], shape[1])
    fig.suptitle(suptitle, y=0.98)
    fig.subplots_adjust(wspace=0.4, hspace=0.6, top=0.85)
    fig.set_size_inches(*figsize)

    # One range per variable shared by every dataset, so the overlaid
    # histograms have identical bin edges and are directly comparable.
    if isinstance(ranges, (int, float)):
        bounds_by_var = {}
        for data in data_list:
            for var in data:
                lo, hi = auto_range([data[var]], ranges, modify=0.1, symmetric=True)
                if var in bounds_by_var:
                    lo = min(bounds_by_var[var][0], lo)
                    hi = max(bounds_by_var[var][1], hi)
                bounds_by_var[var] = (lo, hi)
    else:
        bounds_by_var = ranges

    nrows, ncols = shape[0], shape[1]
    fits = [{} for _ in data_list]
    last_dataset = len(data_list) - 1

    for i, data in enumerate(data_list):
        for j, var in enumerate(data.keys()):
            # plt.subplots returns a 2-D array, a 1-D array or a single Axes
            # depending on the grid shape.
            if nrows >= 2 and ncols >= 2:
                ax = axs[j // ncols, j % ncols]
            elif ncols >= 2:
                ax = axs[j % ncols]
            elif nrows >= 2:
                ax = axs[j // ncols]
            else:
                ax = axs

            ax.set_xlabel(labels[var])
            ax.hist(data[var], nbins, range=bounds_by_var[var], alpha=0.6, label=data_labels[i], rasterized=True)
            ax.set_ylabel("Entries")

            p16, p84 = np.percentile(data[var], [16, 84])
            sig68 = (p84 - p16) / 2
            median = np.median(data[var])
            fits[i][var] = fr"Median = {median:.3}" + var.unit.dname + "\n" + fr"$\sigma_{{68}}$ = {sig68:.3}" + var.unit.dname

            # Once every dataset is drawn, relabel the histogram handles with
            # the per-dataset statistics.
            if i == last_dataset:
                handles, _ = ax.get_legend_handles_labels()
                ax.legend(handles=handles, labels=[d[var] for d in fits], framealpha=0, loc='upper left')

    fig.legend(data_labels, loc="upper center", ncol=len(data_labels), bbox_to_anchor=(0.5, 0.95))
    plt.savefig(format=f"{file_format}", fname=f"{output_dir}/{suptitle.replace(' ', '_')}.{file_format}")
    plt.close()
594
595
def plot_resolution(
    suptitle: str,
    datadict: dict,
    data_labels: list,
    axlabels: dict,
    xlimit: list,
    ylimits: dict,
    bins,
    fitfunction: callable = None,
    fitlabel: callable = None,
    fitrange: list = None,
    figsize: tuple = (12.0, 7.0),
    err_override: dict = None,
):
    """Plot sigma68 vs an x variable (e.g. pseudomomentum) for each observable.

    One subplot is produced per variable in ``datadict``. An optional
    parametric fit is overlaid on each dataset's sigma68 profile. Saves to
    ``{output_dir}/{suptitle}{var_names}.{file_format}``.

    Parameters
    ----------
    suptitle : str
        Figure title, also used (with variable names appended) to derive the
        output file name.
    datadict : dict
        Maps variable objects to a list of ``[xdata, ydata]`` pairs, one per
        dataset. ``xdata`` and ``ydata`` are 1-D arrays of equal length.
    data_labels : list of str
        Legend labels, one per dataset.
    axlabels : dict
        Maps variable objects to ``[xlabel, ylabel]`` string pairs.
    xlimit : list of float
        ``[xmin, xmax]`` range for the x axis (also passed to :func:`to_bins`).
    ylimits : dict
        Maps variable objects to ``[ymin, ymax]`` range for the y axis.
    bins : int or sequence of float
        Bin edges or number of bins passed to :func:`to_bins`.
    fitfunction : callable, optional
        Function ``f(x, *params)`` fitted to the sigma68 profile with
        ``scipy.optimize.curve_fit`` (any number of parameters). Only bins
        strictly inside ``fitrange`` with finite values and positive,
        finite uncertainties are used. If None, no fit is drawn.
    fitlabel : callable, optional
        Function ``f(params, errors) -> str`` producing the fit annotation.
        For every dataset but the first on an axes, the first line of the
        returned text is dropped. If None, the fit curve gets no legend entry.
    fitrange : list of float, optional
        ``[xmin, xmax]`` sub-range used for the fit. Defaults to ``xlimit``.
    figsize : tuple of float, optional
        Figure size in inches ``(width, height)``. Default is ``(12.0, 7.0)``.
    err_override : dict, optional
        Maps variable objects to a list of per-dataset sigma68 uncertainty
        arrays that replace the values computed by :func:`to_bins`. Useful for
        providing external uncertainty estimates. Default is None.
    """
    fig, axs = plt.subplots(1, len(datadict))
    fig.set_size_inches(*figsize)
    fig.subplots_adjust(wspace=0.3)
    fig.suptitle(suptitle)

    if fitfunction:
        fit_window = fitrange if fitrange else xlimit
        fit_lo, fit_hi = fit_window[0], fit_window[1]
        curve_x = np.linspace(fit_lo, fit_hi, 100)

    for i, var in enumerate(datadict.keys()):
        ax = axs[i] if len(datadict) > 1 else axs

        if len(data_labels) <= 3:
            color = iter(["b", "r", "g"])
        else:
            color = iter(plt.cm.rainbow(np.linspace(0, 1, len(data_labels))))

        for j, pair in enumerate(datadict[var]):
            c = next(color)
            xdata, ydata = pair[0], pair[1]
            x_vals, _, x_width, _, sig68, sig68_uncert = to_bins(xdata, ydata, bins, xlimit)
            if err_override:
                sig68_uncert = err_override[var][j]
            ax.errorbar(x_vals, sig68, sig68_uncert, x_width, fmt="o", label=data_labels[j], c=c, rasterized=True)

            if not fitfunction:
                continue

            x_arr = np.asarray(x_vals, dtype=float)
            y_arr = np.asarray(sig68, dtype=float)
            err_arr = np.asarray(sig68_uncert, dtype=float)
            # Keep bins strictly inside the fit window, and drop points
            # curve_fit cannot use: non-finite values (rejected by its
            # finiteness check) and non-positive uncertainties (infinite weight).
            fitmask = (
                (fit_lo < x_arr) & (x_arr < fit_hi)
                & np.isfinite(y_arr) & np.isfinite(err_arr) & (err_arr > 0)
            )
            popt, pcov = curve_fit(
                fitfunction,
                x_arr[fitmask],
                y_arr[fitmask],
                sigma=err_arr[fitmask],
            )
            perr = np.sqrt(np.diag(pcov))

            fittext = None
            if fitlabel is not None:
                lines = fitlabel(popt, perr).splitlines()
                fittext = "\n".join(lines if j == 0 else lines[1:])
            # Pass every fitted parameter, not just the first two.
            ax.plot(curve_x, fitfunction(curve_x, *popt), label=fittext, c=c)

        ax.set_xlim(xlimit[0], xlimit[1])
        ax.set_ylim(ylimits[var][0], ylimits[var][1])
        ax.set_xlabel(axlabels[var][0])
        ax.set_ylabel(axlabels[var][1])
        ax.legend()

    varsplaintext = f"{[var.plaintext for var in datadict.keys()]}".replace("'", "").replace(" ", "")
    plt.savefig(format=f"{file_format}",
                fname=f"{output_dir}/{suptitle.replace(' ', '_')}{varsplaintext}.{file_format}")
    plt.close()