Belle II Software  release-06-02-00
svd_time.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 from pathlib import Path
13 import sys
14 
15 import pandas as pd
16 import seaborn as sns
17 import matplotlib
18 import matplotlib.pyplot as plt
19 import matplotlib.ticker as ticker
20 
21 from prompt import ValidationSettings
22 import svd.validation_utils as vu
23 
24 import ROOT as r
# Tell PyROOT to leave sys.argv alone so that the argparse parser at the
# bottom of this file is the only consumer of the command line options.
r.PyConfig.IgnoreCommandLineOptions = True

# 'Agg' is a non-interactive backend: every figure is only written to file.
matplotlib.use("Agg")
plt.style.use('belle2')
29 
30 
# Module-level description of this validation
# (presumably picked up by the prompt validation framework -- TODO confirm).
settings = ValidationSettings(
    name='caf_svd_time',
    description=__doc__,
    download_files=[],
    expert_config=None,
)
35 
36 
def progress(count, total):
    '''Draw a one-line text progress bar on stdout.

    The bar is 60 characters wide and is terminated with a carriage return
    (no newline), so that successive calls overwrite each other.

    count: number of items processed so far
    total: total number of items; if it is 0 an empty bar with 0.0% is drawn
           instead of raising ZeroDivisionError
    '''
    bar_len = 60
    if total:
        filled_len = int(round(bar_len * count / total))
        percents = round(100 * count / total, 1)
    else:
        # Nothing to process: there is no meaningful fraction to show
        filled_len = 0
        percents = 0.0
    bar = '=' * filled_len + '-' * (bar_len - filled_len)
    sys.stdout.write(f'[{bar}] {percents}%\r')
    sys.stdout.flush()
44 
45 
def _process_histos(histos, exp, run, algo, plots_per_run):
    '''Compute the per-run quantities of one algorithm and save its per-run plots.

    Returns the tuple (agreement, precision, discrimination, entries_onTracks,
    roc_U, roc_V); the first four are dictionaries keyed like the input histograms.
    An AttributeError raised by any of the histogram operations is propagated to the caller.
    '''
    agreement = {key: vu.get_agreament(histos['eventT0'], h_diff)
                 for key, h_diff in histos['diff'].items()}
    precision = {key: vu.get_precision(h_diff)
                 for key, h_diff in histos['diff'].items()}
    discrimination = {key: vu.get_roc_auc(histos['onTracks'][key], histos['offTracks'][key])
                      for key in histos['onTracks']}
    n_on_tracks = {key: val.GetEntries() for key, val in histos['onTracks'].items()}

    for side in ('U', 'V'):
        vu.make_combined_plot(f'*{side}', histos,
                              title=f'exp {exp} run {run} {side} {algo}')
        plt.savefig(plots_per_run / f'{exp}_{run}_{side}_{algo}.pdf')
        plt.close()

    roc_u = vu.make_roc(vu.get_combined(histos['onTracks'], '*U'),
                        vu.get_combined(histos['offTracks'], '*U'))
    roc_v = vu.make_roc(vu.get_combined(histos['onTracks'], '*V'),
                        vu.get_combined(histos['offTracks'], '*V'))

    return agreement, precision, discrimination, n_on_tracks, roc_u, roc_v


def _save_violin_plot(df, column, out_file, fig_width, ylim=None, hlines=(), **violin_kwargs):
    '''Draw a split (by side) violin plot of df[column] versus run and save it to out_file.

    ylim: optional y axis limits
    hlines: (y, linestyle) pairs drawn as black horizontal reference lines
    violin_kwargs: extra keyword arguments forwarded to seaborn.violinplot
    '''
    plt.figure(figsize=(fig_width, 4.8*2))
    ax = sns.violinplot(x='run', y=column, hue='side', data=df, split=True, **violin_kwargs)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.xaxis.set_minor_locator(ticker.NullLocator())
    for y, linestyle in hlines:
        plt.axhline(y, color='black', linestyle=linestyle)
    plt.setp(ax.get_xticklabels(), rotation=90)
    plt.tight_layout()
    plt.savefig(out_file)
    plt.close()


def run_validation(job_path, input_data_path=None, **kwargs):
    '''Validate the SVD time calibration using the collector outputs found under job_path.

    job_path will be replaced with path/to/calibration_results
    input_data_path will be replaced with path/to/data_path used for calibration
    e.g. /group/belle2/dataprod/Data/PromptSkim/ (it is accepted but not used here)
    kwargs: only 'output_dir' is read (default 'SVDTimeValidation_output')

    Every CollectorOutput.root below
    job_path/SVDTimeValidation/0/collector_output/default/ is read. For each run and each
    algorithm in vu.time_algorithms the file with the largest number of eventT0 entries is kept.

    Written to output_dir:
      - runs/{exp}_{run}_{U,V}_{algo}.pdf and runs/ROC_{run}.pdf
      - df.pkl, the pickled pandas DataFrame with one row per (run, sensor side name)
      - {agreement,precision,discrimination,entries_onTracks,entries_eventT0}_{algo}.pdf
    '''

    collector_output_dir = Path(job_path) / 'SVDTimeValidation/0/collector_output/default/'
    output_dir = Path(kwargs.get('output_dir', 'SVDTimeValidation_output'))
    plots_per_run = output_dir / 'runs'

    plots_per_run.mkdir(parents=True, exist_ok=True)

    files = list(collector_output_dir.glob('**/CollectorOutput.root'))

    agreements = {algo: {} for algo in vu.time_algorithms}
    precisions = {algo: {} for algo in vu.time_algorithms}
    discriminations = {algo: {} for algo in vu.time_algorithms}
    entries_onTracks = {algo: {} for algo in vu.time_algorithms}
    entries_eventT0 = {algo: {} for algo in vu.time_algorithms}

    roc_U = {algo: {} for algo in vu.time_algorithms}
    roc_V = {algo: {} for algo in vu.time_algorithms}

    num_files = len(files)
    print(f'Looping over {num_files} files')
    if num_files:
        # With no files there is no fraction to display (and nothing to divide by)
        progress(0, num_files)
    for count, in_file_name in enumerate(files):

        in_file = r.TFile(str(in_file_name))

        # Make sure the ROOT file is closed even if something unexpected is raised
        try:
            for algo in vu.time_algorithms:

                histos, exp, run = vu.get_histos(in_file, algo)

                if histos is None:
                    print(f'Skipping file {in_file_name} for {algo}')
                    continue

                # if some histogram is empty (too little stat) do not crash but skip that file for that calibration
                try:
                    entries_eventT0_ = histos['eventT0'].GetEntries()
                    if run not in entries_eventT0[algo] or entries_eventT0_ > entries_eventT0[algo][run]:
                        results = _process_histos(histos, exp, run, algo, plots_per_run)
                    else:
                        # A file with at least as many entries was already used for this run
                        results = None
                except AttributeError:
                    print(f'Skipping file {in_file_name} for {algo}')
                    continue

                if results is None:
                    continue

                # Store everything only now that all the quantities are available, so that
                # the dictionaries below always share the same set of runs for a given algorithm
                (agreements[algo][run],
                 precisions[algo][run],
                 discriminations[algo][run],
                 entries_onTracks[algo][run],
                 roc_U[algo][run],
                 roc_V[algo][run]) = results
                entries_eventT0[algo][run] = entries_eventT0_
        finally:
            in_file.Close()

        # Show the progress
        progress(count+1, num_files)

    print()

    # Keep only the runs for which every algorithm produced results: the table below
    # needs one value per algorithm in every row
    runs_per_algo = [set(agreements[algo]) for algo in vu.time_algorithms]
    runs = sorted(set.intersection(*runs_per_algo))
    incomplete_runs = sorted(set.union(*runs_per_algo) - set(runs))
    if incomplete_runs:
        print(f'Ignoring runs without results for all the algorithms: {incomplete_runs}')

    dd = {}
    dd['run'] = [run for run in runs for _ in vu.names_sides]
    dd['name'] = vu.names_sides*len(runs)
    dd['side'] = [i[-1] for i in dd['name']]

    for algo in vu.time_algorithms:
        dd[f'agreement_{algo}'] = [agreements[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'precision_{algo}'] = [precisions[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'discrimination_{algo}'] = [discriminations[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'entries_onTracks_{algo}'] = [entries_onTracks[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'entries_eventT0_{algo}'] = [entries_eventT0[algo][run] for run in dd['run']]

    # Make ROC plots (fixed drawing order and colors; algorithms that are not
    # listed in vu.time_algorithms are simply not drawn)
    roc_colors = (('CoG6', 'k'), ('CoG3', 'b'), ('ELS3', 'r'))
    for run in runs:
        plt.figure()
        for algo, color in roc_colors:
            if algo not in roc_U:
                continue
            plt.plot(*roc_U[algo][run], f'{color}-', label=f'{algo} U')
            plt.plot(*roc_V[algo][run], f'{color}:', label=f'{algo} V')
        plt.legend(loc='lower left')
        plt.xlabel('sgn efficiency')
        plt.ylabel('bkg rejection')
        plt.title(f'ROC run {run}')
        plt.xlim((0, 1))
        plt.ylim((0, 1))
        plt.tight_layout()
        plt.savefig(plots_per_run / f'ROC_{run}.pdf')
        plt.close()

    df = pd.DataFrame(dd)
    df.to_pickle(output_dir / 'df.pkl')

    # df = pd.read_pickle('df.pkl')

    if not runs:
        # There is nothing to draw; NOTE(review): depending on its version seaborn
        # may even raise when it is given an empty DataFrame
        print('No runs to plot, skipping the combined plots')
        return

    print('Making combined plots')

    fig_width = 6.4*max(2, num_files/30)
    for algo in vu.time_algorithms:
        _save_violin_plot(df, f'agreement_{algo}', output_dir / f'agreement_{algo}.pdf', fig_width,
                          ylim=[-2, 2], hlines=((0, '--'), (0.5, ':'), (-0.5, ':')))
        _save_violin_plot(df, f'precision_{algo}', output_dir / f'precision_{algo}.pdf', fig_width,
                          ylim=[0, 50], hlines=((10, ':'), (20, ':')))
        _save_violin_plot(df, f'discrimination_{algo}', output_dir / f'discrimination_{algo}.pdf', fig_width,
                          ylim=[0.5, 1], hlines=((0.8, ':'), (0.9, ':')))
        _save_violin_plot(df, f'entries_onTracks_{algo}', output_dir / f'entries_onTracks_{algo}.pdf', fig_width,
                          cut=0)
        _save_violin_plot(df, f'entries_eventT0_{algo}', output_dir / f'entries_eventT0_{algo}.pdf', fig_width)
210 
211 
if __name__ == '__main__':

    import argparse

    arg_parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # b2val-prompt-run also passes input_data_path and requested_iov to this
    # script. This validation does not need them, so the positional argument
    # accepts one or more values and only the first one (the calibration
    # results directory) is used below.
    arg_parser.add_argument(
        'calibration_results_dir',
        nargs='+',
        help='The directory that contains the collector outputs',
    )
    arg_parser.add_argument(
        '-o', '--output_dir',
        default='SVDTimeValidation_output',
        help='The directory where all the output will be saved',
    )

    cli_args = arg_parser.parse_args()
    results_dir = cli_args.calibration_results_dir[0]

    run_validation(results_dir, output_dir=cli_args.output_dir)