Belle II Software  release-05-01-25
svd_time.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 from pathlib import Path
5 import sys
6 
7 import pandas as pd
8 import seaborn as sns
9 import matplotlib
10 import matplotlib.pyplot as plt
11 import matplotlib.ticker as ticker
12 
13 from prompt import ValidationSettings
14 import svd.validation_utils as vu
15 
16 import ROOT as r
# Tell PyROOT to ignore the command-line options, so that they are left
# to the argparse parser of this script
r.PyConfig.IgnoreCommandLineOptions = True

# Select the non-interactive Agg backend: every figure made here is saved to file
matplotlib.use('Agg')
# NOTE(review): the "belle2" style sheet is presumably provided by the
# Belle II software environment -- confirm it is available where this runs
plt.style.use("belle2")
21 
22 
# Settings object of this validation, built with ValidationSettings from the prompt package.
# NOTE(review): no module docstring is visible at the top of this file, so
# __doc__ (passed as description) is presumably None -- confirm
settings = ValidationSettings(name="caf_svd_time",
                              description=__doc__,
                              download_files=[],
                              expert_config=None)
27 
28 
def progress(count, total):
    '''Draw a one-line text progress bar on stdout.

    The bar is 60 characters wide and is followed by the completed
    percentage; the line ends with a carriage return so that the next call
    overwrites it.

    count is the number of items processed so far, total is the number of
    items to process. A total of zero (nothing to process) is shown as a
    complete bar instead of raising ZeroDivisionError.'''
    bar_len = 60
    if total != 0:
        filled_len = int(round(bar_len * count / total))
        percents = round(100 * count / total, 1)
    else:
        # Nothing to process: there is no meaningful ratio, show the job as done
        filled_len = bar_len
        percents = 100.0
    bar = '=' * filled_len + '-' * (bar_len - filled_len)
    sys.stdout.write(f'[{bar}] {percents}%\r')
    sys.stdout.flush()
36 
37 
def _save_violin_plot(df, column, figsize, out_path, ylim=None, hlines=()):
    '''Draw the violin plot of df[column] versus run (hue='side', split) and save it to out_path.

    ylim, when given, is passed to set_ylim; hlines is an iterable of
    (y, linestyle) pairs drawn as black horizontal reference lines.'''
    plt.figure(figsize=figsize)
    ax = sns.violinplot(x='run', y=column, hue='side', data=df, split=True)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.xaxis.set_minor_locator(ticker.NullLocator())
    for y, linestyle in hlines:
        plt.axhline(y, color='black', linestyle=linestyle)
    plt.setp(ax.get_xticklabels(), rotation=90)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def run_validation(job_path, input_data_path=None, **kwargs):
    '''Produce the SVD time validation plots from the collector outputs.

    job_path will be replaced with path/to/calibration_results
    input_data_path will be replaced with path/to/data_path used for calibration
    e.g. /group/belle2/dataprod/Data/PromptSkim/ (it is accepted but not used here)

    The only keyword argument read is output_dir (default
    'SVDTimeValidation_output'): the directory where the output is saved.

    Written to output_dir:
      - runs/<exp>_<run>_<U|V>_<algo>.pdf and runs/ROC_<run>.pdf
      - df.pkl, the pickled DataFrame with one row per (run, name)
      - <quantity>_<algo>.pdf, the violin plots of each quantity versus run

    Raises FileNotFoundError if no CollectorOutput.root file is found under
    job_path/SVDTimeValidation/0/collector_output/default/.'''

    collector_output_dir = Path(job_path) / 'SVDTimeValidation/0/collector_output/default/'
    output_dir = Path(kwargs.get('output_dir', 'SVDTimeValidation_output'))
    plots_per_run = output_dir / 'runs'

    plots_per_run.mkdir(parents=True, exist_ok=True)

    files = list(collector_output_dir.glob('**/CollectorOutput.root'))
    num_files = len(files)
    if num_files == 0:
        # Without inputs there is nothing to validate: fail with a clear message
        # instead of an unrelated ZeroDivisionError from the progress bar
        raise FileNotFoundError(f'No CollectorOutput.root file found under {collector_output_dir}')

    agreements = {algo: {} for algo in vu.time_algorithms}
    precisions = {algo: {} for algo in vu.time_algorithms}
    discriminations = {algo: {} for algo in vu.time_algorithms}
    entries_onTracks = {algo: {} for algo in vu.time_algorithms}
    entries_eventT0 = {algo: {} for algo in vu.time_algorithms}

    roc_U = {algo: {} for algo in vu.time_algorithms}
    roc_V = {algo: {} for algo in vu.time_algorithms}

    print(f'Looping over {num_files} files')
    progress(0, num_files)
    for count, in_file_name in enumerate(files, start=1):

        in_file = r.TFile(str(in_file_name))
        try:
            for algo in vu.time_algorithms:

                histos, exp, run = vu.get_histos(in_file, algo)

                # When several files belong to the same run, keep the numbers
                # of the one with the largest number of eventT0 entries
                entries_eventT0_ = histos['eventT0'].GetEntries()
                if run not in entries_eventT0[algo] or entries_eventT0_ > entries_eventT0[algo][run]:
                    agreements[algo][run] = {key: vu.get_agreament(histos['eventT0'], h_diff)
                                             for key, h_diff in histos['diff'].items()}
                    precisions[algo][run] = {key: vu.get_precision(h_diff)
                                             for key, h_diff in histos['diff'].items()}
                    discriminations[algo][run] = {key: vu.get_roc_auc(histos['onTracks'][key], histos['offTracks'][key])
                                                  for key in histos['onTracks']}
                    entries_onTracks[algo][run] = {key: val.GetEntries() for key, val in histos['onTracks'].items()}
                    entries_eventT0[algo][run] = entries_eventT0_

                # The per-run plots and ROC curves are (re)made for every file
                for side in ('U', 'V'):
                    vu.make_combined_plot(f'*{side}', histos,
                                          title=f'exp {exp} run {run} {side} {algo}')
                    plt.savefig(plots_per_run / f'{exp}_{run}_{side}_{algo}.pdf')
                    plt.close()

                roc_U[algo][run] = vu.make_roc(vu.get_combined(histos['onTracks'], '*U'),
                                               vu.get_combined(histos['offTracks'], '*U'))
                roc_V[algo][run] = vu.make_roc(vu.get_combined(histos['onTracks'], '*V'),
                                               vu.get_combined(histos['offTracks'], '*V'))
        finally:
            # Release the ROOT file even if the processing of its histograms fails
            in_file.Close()

        # Show the progress
        progress(count, num_files)

    print()

    # Build the columns of the summary table: one row per (run, name)
    dd = {}
    runs = sorted(agreements[vu.time_algorithms[0]])
    dd['run'] = [run for run in runs for _ in vu.names_sides]
    dd['name'] = vu.names_sides*len(runs)
    dd['side'] = [name[-1] for name in dd['name']]

    run_name_pairs = list(zip(dd['run'], dd['name']))
    for algo in vu.time_algorithms:
        dd[f'agreement_{algo}'] = [agreements[algo][run][name] for run, name in run_name_pairs]
        dd[f'precision_{algo}'] = [precisions[algo][run][name] for run, name in run_name_pairs]
        dd[f'discrimination_{algo}'] = [discriminations[algo][run][name] for run, name in run_name_pairs]
        dd[f'entries_onTracks_{algo}'] = [entries_onTracks[algo][run][name] for run, name in run_name_pairs]
        dd[f'entries_eventT0_{algo}'] = [entries_eventT0[algo][run] for run in dd['run']]

    # Make ROC plots
    # The three algorithms drawn here must be among vu.time_algorithms,
    # as roc_U and roc_V are keyed by them
    roc_colors = (('CoG6', 'k'), ('CoG3', 'b'), ('ELS3', 'r'))
    for run in runs:
        plt.figure()
        for algo, color in roc_colors:
            plt.plot(*roc_U[algo][run], f'{color}-', label=f'{algo} U')
            plt.plot(*roc_V[algo][run], f'{color}:', label=f'{algo} V')
        plt.legend(loc='lower left')
        plt.xlabel('sgn efficiency')
        plt.ylabel('bkg rejection')
        plt.title(f'ROC run {run}')
        plt.xlim((0, 1))
        plt.ylim((0, 1))
        plt.tight_layout()
        plt.savefig(plots_per_run / f'ROC_{run}.pdf')
        plt.close()

    df = pd.DataFrame(dd)
    df.to_pickle(output_dir / 'df.pkl')

    print('Making combined plots')

    # The figures get wider as the number of processed files grows
    figsize = (6.4*max(2, num_files/30), 4.8*2)

    for algo in vu.time_algorithms:
        _save_violin_plot(df, f'agreement_{algo}', figsize, output_dir / f'agreement_{algo}.pdf',
                          ylim=[-2, 2], hlines=((0, '--'), (0.5, ':'), (-0.5, ':')))
        _save_violin_plot(df, f'precision_{algo}', figsize, output_dir / f'precision_{algo}.pdf',
                          ylim=[0, 50], hlines=((10, ':'), (20, ':')))
        _save_violin_plot(df, f'discrimination_{algo}', figsize, output_dir / f'discrimination_{algo}.pdf',
                          ylim=[0.5, 1], hlines=((0.8, ':'), (0.9, ':')))
        _save_violin_plot(df, f'entries_onTracks_{algo}', figsize, output_dir / f'entries_onTracks_{algo}.pdf')
        _save_violin_plot(df, f'entries_eventT0_{algo}', figsize, output_dir / f'entries_eventT0_{algo}.pdf')
193 
194 
if __name__ == '__main__':

    import argparse

    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # b2val-prompt-run also hands input_data_path and requested_iov to the
    # script. This validation does not need them, so they are collected
    # together with calibration_results_dir and only the first value is used
    cli.add_argument(
        'calibration_results_dir',
        nargs='+',
        help='The directory that contains the collector outputs',
    )
    cli.add_argument(
        '-o', '--output_dir',
        default='SVDTimeValidation_output',
        help='The directory where all the output will be saved',
    )
    options = cli.parse_args()

    # nargs='+' guarantees at least one positional value
    results_dir, *_ignored = options.calibration_results_dir
    run_validation(results_dir, output_dir=options.output_dir)
svd.validation_utils
Definition: validation_utils.py:1