Belle II Software development
svd_time.py
1#!/usr/bin/env python3
2
3
10
11from pathlib import Path
12
13import pandas as pd
14import seaborn as sns
15import matplotlib
16import matplotlib.pyplot as plt
17import matplotlib.ticker as ticker
18import re
19
20from prompt import ValidationSettings
21import svd.validation_utils as vu
22
23import ROOT as r
# Stop ROOT from parsing this script's command line, so the options are left to the argparse parser below.
r.PyConfig.IgnoreCommandLineOptions = True
# Run ROOT in batch mode: canvases are only printed to file, no graphics window is opened.
r.gROOT.SetBatch()

# Select the non-interactive Agg backend: every figure made by this script is saved to file.
matplotlib.use('Agg')
# NOTE(review): "belle2" is not a built-in matplotlib style; presumably it is provided by the
# Belle II software environment - confirm it is available wherever this validation runs.
plt.style.use("belle2")
29
30
# Settings object describing this validation; it needs no extra downloaded files
# and no expert configuration.
settings = ValidationSettings(
    name="caf_svd_time",
    description=__doc__,
    download_files=[],
    expert_config=None,
)
35
36
# Summary violin plots drawn for every algorithm, in drawing order.
# Each entry is (quantity, options): 'ylim' fixes the y range, 'hlines' lists the
# (y, linestyle) guide lines, every other option is forwarded to seaborn.violinplot.
_SUMMARY_PLOTS = (
    ('agreement', {'ylim': [-2, 2], 'hlines': ((0, '--'), (0.5, ':'), (-0.5, ':'))}),
    ('precision', {'ylim': [0, 50], 'hlines': ((10, ':'), (20, ':'))}),
    ('discrimination', {'ylim': [0.5, 1], 'hlines': ((0.8, ':'), (0.9, ':'))}),
    ('shift_agreement', {'ylim': [0.0, 3.5], 'hlines': ((0, '--'), (0.5, ':'), (1.0, ':'), (2.0, ':')), 'cut': 0}),
    ('entries_onTracks', {'cut': 0}),
    ('entries_eventT0', {}),
)

# Line colour of each algorithm in the per-run ROC plots (solid line for U, dotted for V).
_ROC_COLOURS = {'CoG6': 'k', 'CoG3': 'b', 'ELS3': 'r'}


def _evaluate_run(histos):
    '''Return the per-sensor-side quantities of one run, keyed by the name of the result they belong to.'''
    return {
        'agreements': {key: vu.get_agreement(histos['eventT0'], h_diff)
                       for key, h_diff in histos['diff'].items()},
        'precisions': {key: vu.get_precision(h_diff)
                       for key, h_diff in histos['diff'].items()},
        'discriminations': {key: vu.get_roc_auc(histos['onTracks'][key], histos['offTracks'][key])
                            for key in histos['onTracks']},
        'shift_agreements': {key: vu.get_shift_agreement(hShift)
                             for key, hShift in histos['timeShifter'].items()},
        'entries_onTracks': {key: val.GetEntries() for key, val in histos['onTracks'].items()},
    }


def _accumulate_shift_histograms(time_shifter_histos, per_key, per_group):
    '''Add the time-shifter histograms of one run to the running sums per key and per key group.'''
    for key, hShift in time_shifter_histos.items():
        if key in per_key:
            per_key[key].Add(hShift)
        else:
            # Detach the clone from any ROOT directory so it outlives the run histograms
            per_key[key] = hShift.Clone()
            per_key[key].SetDirectory(0)
        # The group is named after the first and third number found in the key plus the
        # last character of the key; the second number is dropped (presumably the ladder,
        # given the name of the container this fills - confirm against validation_utils).
        sensor_id = re.findall(r'\d+', key) + [key[-1]]
        keyGroup = f'L{sensor_id[0]}S{sensor_id[2]}{sensor_id[3]}'
        if keyGroup in per_group:
            per_group[keyGroup].Add(hShift)
        else:
            per_group[keyGroup] = hShift.Clone()
            per_group[keyGroup].SetDirectory(0)


def _save_run_plots(histos, algo, exp, run, plots_per_run):
    '''Save the combined U and V plots of one run and return its (U, V) ROC curves.'''
    for side in ('U', 'V'):
        vu.make_combined_plot(f'*{side}', histos,
                              title=f'exp {exp} run {run} {side} {algo}')
        plt.savefig(plots_per_run / f'{exp}_{run}_{side}_{algo}.pdf')
        plt.close()

    run_roc_U = vu.make_roc(vu.get_combined(histos['onTracks'], '*U'),
                            vu.get_combined(histos['offTracks'], '*U'))
    run_roc_V = vu.make_roc(vu.get_combined(histos['onTracks'], '*V'),
                            vu.get_combined(histos['offTracks'], '*V'))
    return run_roc_U, run_roc_V


def _release_histograms(histos, collector_histos):
    '''Delete the ROOT histograms of a fully processed run to free memory.'''
    # Free-up memory manually as `SetDirectory(0)` was used
    histos['eventT0'].Delete()
    del histos['eventT0']
    for histo_dict in histos.values():
        for hh in histo_dict.values():
            hh.Delete()

    for key, hh in collector_histos.items():
        if key != 'hEventT0':
            hh.Delete()


def _print_shift_histograms(output_dir, shift_histos, shift_histos_merged_over_ladder):
    '''Write, for every algorithm, a multi-page PDF with the accumulated time-shifter histograms.

    The first page shows the plot returned by ``vu.get_shift_plot`` for the merged histograms,
    the following pages show the individual histograms, two per page, each row of bins
    normalised to its maximum.
    '''
    for algo, KeyHisto in shift_histos.items():
        c2 = r.TCanvas("c2", "c2", 640, 480)
        outPDF = f"{output_dir}/shift_histograms_{algo}.pdf"
        c2.Print(outPDF + "[")  # "[" opens the multi-page PDF without printing a page
        onePad = r.TPad("onePad", "onePad", 0, 0, 1, 1)
        onePad.SetMargin(0.1, 0.2, 0.1, 0.1)
        onePad.SetNumber(1)
        onePad.Draw()
        onePad.cd()
        hShiftHisto = vu.get_shift_plot(shift_histos_merged_over_ladder[algo])
        hShiftHisto.Draw('COLZ')
        c2.Print(outPDF, "Title:" + hShiftHisto.GetName())

        c1 = r.TCanvas("c1", "c1", 640, 480)
        topPad = r.TPad("topPad", "topPad", 0, 0.5, 1, 1)
        btmPad = r.TPad("btmPad", "btmPad", 0, 0, 1, 0.5)
        topPad.SetMargin(0.1, 0.1, 0, 0.149)
        btmPad.SetMargin(0.1, 0.1, 0.303, 0)
        topPad.SetNumber(1)
        btmPad.SetNumber(2)
        topPad.Draw()
        btmPad.Draw()
        isOdd = True
        pending = None  # histogram drawn on the top pad whose page has not been printed yet
        for hShift in KeyHisto.values():
            hShift.SetStats(0)
            # Scale every y row so that its highest bin is 1; empty rows are left untouched
            for yn in range(hShift.GetNbinsY()):
                norm = (hShift.ProjectionX("tmp", yn + 1, yn + 1, "")).GetMaximum()
                if norm <= 0:
                    continue
                for xn in range(hShift.GetNbinsX()):
                    hShift.SetBinContent(xn + 1, yn + 1, hShift.GetBinContent(xn + 1, yn + 1) / norm)
            if isOdd:
                topPad.cd()
                hShift.Draw("colz")
                pending = hShift
            else:
                btmPad.cd()
                hShift.Draw("colz")
                # The page is printed once both pads are filled
                c1.Print(outPDF, "Title:" + hShift.GetName())
                pending = None
            isOdd = not isOdd
        if pending is not None:
            # Odd number of histograms: print the last one too, without the stale
            # content that the previous page left on the bottom pad
            btmPad.Clear()
            c1.Print(outPDF, "Title:" + pending.GetName())
        c1.Print(outPDF + "]")  # "]" closes the PDF


def _save_summary_plot(df, quantity, algo, output_dir, total_length, ylim=None, hlines=(), **violin_options):
    '''Save the split (by side) violin plot of ``{quantity}_{algo}`` versus run as ``{quantity}_{algo}.pdf``.'''
    plt.figure(figsize=(6.4*max(2, total_length/30), 4.8*2))
    ax = sns.violinplot(x='run', y=f'{quantity}_{algo}', hue='side', data=df, split=True, **violin_options)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.xaxis.set_minor_locator(ticker.NullLocator())
    for level, linestyle in hlines:
        plt.axhline(level, color='black', linestyle=linestyle)
    plt.setp(ax.get_xticklabels(), rotation=90)
    plt.tight_layout()
    plt.savefig(output_dir / f'{quantity}_{algo}.pdf')
    plt.close()


def run_validation(job_path, input_data_path=None, **kwargs):
    '''Produce the SVD time validation plots from the collector output of a calibration job.

    job_path will be replaced with path/to/calibration_results
    input_data_path will be replaced with path/to/data_path used for calibration
    e.g. /group/belle2/dataprod/Data/PromptSkim/ (it is not used by this validation)

    Recognised keyword arguments:
      output_dir: directory receiving all the output (default 'SVDTimeValidation_output');
                  the per-run plots go into its 'runs' subdirectory
      shift_detailed: if true, also write one PDF per algorithm with the accumulated
                      time-shifter histograms (default False)

    Besides the plots, the per-run, per-sensor-side summary table is saved as 'df.pkl'
    in the output directory. Nothing is returned.
    '''

    collector_output_dir = Path(job_path) / 'SVDTimeValidation/0/collector_output/default/'
    output_dir = Path(kwargs.get('output_dir', 'SVDTimeValidation_output'))
    shift_detailed = kwargs.get('shift_detailed', False)
    plots_per_run = output_dir / 'runs'

    plots_per_run.mkdir(parents=True, exist_ok=True)

    files = list(collector_output_dir.glob('**/CollectorOutput.root'))

    agreements = {algo: {} for algo in vu.time_algorithms}
    precisions = {algo: {} for algo in vu.time_algorithms}
    discriminations = {algo: {} for algo in vu.time_algorithms}
    shift_agreements = {algo: {} for algo in vu.time_algorithms}
    entries_onTracks = {algo: {} for algo in vu.time_algorithms}
    entries_eventT0 = {algo: {} for algo in vu.time_algorithms}

    roc_U = {algo: {} for algo in vu.time_algorithms}
    roc_V = {algo: {} for algo in vu.time_algorithms}

    CollectorHistograms = vu.get_merged_collector_histograms(files)

    max_total_run = 0
    total_item = 0
    for algo in CollectorHistograms:
        for exp in CollectorHistograms[algo]:
            nRun = len(CollectorHistograms[algo][exp])
            total_item += nRun
            if nRun > max_total_run:
                max_total_run = nRun
    # Only used to scale the width of the summary figures
    total_length = max_total_run * len(vu.time_algorithms)

    print(f'Looping over {total_item} items')
    count = 0
    vu.progress(0, total_item)

    shift_histos = {}
    shift_histos_merged_over_ladder = {}

    for algo in CollectorHistograms:
        shift_histos[algo] = {}
        shift_histos_merged_over_ladder[algo] = {}
        for exp in CollectorHistograms[algo]:
            for run in CollectorHistograms[algo][exp]:
                # Skipped items count as well, so that the progress reaches total_item
                count += 1
                collector_histos = CollectorHistograms[algo][exp][run]
                histos = vu.get_histos(collector_histos)

                if histos is None:
                    print(f'Skipping file algo {algo} exp {exp} run {run}')
                    vu.progress(count, total_item)
                    continue

                # if some histogram is empty (too little stat) do not crash but skip that file for that calibration
                try:
                    entries_eventT0_ = histos['eventT0'].GetEntries()
                    # Results are stored per run number only: when the same run number shows up
                    # again, the entry with more eventT0 entries wins.
                    # NOTE(review): the listing this block was restored from had lost its
                    # indentation; the shift accumulation, the per-run plots and the ROC curves
                    # are assumed to sit inside this condition - confirm against the repository.
                    if run not in entries_eventT0[algo] or entries_eventT0_ > entries_eventT0[algo][run]:
                        run_metrics = _evaluate_run(histos)

                        if shift_detailed:
                            _accumulate_shift_histograms(histos['timeShifter'],
                                                         shift_histos[algo],
                                                         shift_histos_merged_over_ladder[algo])

                        run_roc_U, run_roc_V = _save_run_plots(histos, algo, exp, run, plots_per_run)

                        # Store the results only once every step succeeded: a run is then
                        # either present in all the containers or in none of them
                        agreements[algo][run] = run_metrics['agreements']
                        precisions[algo][run] = run_metrics['precisions']
                        discriminations[algo][run] = run_metrics['discriminations']
                        shift_agreements[algo][run] = run_metrics['shift_agreements']
                        entries_onTracks[algo][run] = run_metrics['entries_onTracks']
                        entries_eventT0[algo][run] = entries_eventT0_
                        roc_U[algo][run] = run_roc_U
                        roc_V[algo][run] = run_roc_V
                except AttributeError:
                    print(f'Skipping file algo {algo} exp {exp} run {run}')
                    vu.progress(count, total_item)
                    continue

                _release_histograms(histos, collector_histos)
                del histos

                vu.progress(count, total_item)

    print()

    if shift_detailed:
        _print_shift_histograms(output_dir, shift_histos, shift_histos_merged_over_ladder)

    # The summary is built for the runs of the first algorithm, restricted to those for
    # which every algorithm has results (the table needs one value per algorithm and run)
    reference_runs = agreements[vu.time_algorithms[0]]
    runs = sorted(run for run in reference_runs
                  if all(run in agreements[algo] for algo in vu.time_algorithms))
    incomplete_runs = sorted(set(reference_runs) - set(runs))
    if incomplete_runs:
        print(f'Runs without results for every algorithm are left out of the summary: {incomplete_runs}')

    dd = {}
    # One row per (run, sensor side)
    dd['run'] = [run for run in runs for _ in vu.names_sides]
    dd['name'] = vu.names_sides*len(runs)
    dd['side'] = [i[-1] for i in dd['name']]

    for algo in vu.time_algorithms:
        dd[f'agreement_{algo}'] = [agreements[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'precision_{algo}'] = [precisions[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'discrimination_{algo}'] = [discriminations[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'shift_agreement_{algo}'] = [shift_agreements[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'entries_onTracks_{algo}'] = [entries_onTracks[algo][run][side] for run, side in zip(dd['run'], dd['name'])]
        dd[f'entries_eventT0_{algo}'] = [entries_eventT0[algo][run] for run in dd['run']]

    # Make ROC plots
    for run in runs:
        plt.figure()
        for algo, colour in _ROC_COLOURS.items():
            # U and V curves are always stored together, so checking one is enough
            if run not in roc_U.get(algo, {}):
                continue
            plt.plot(*roc_U[algo][run], f'{colour}-', label=f'{algo} U')
            plt.plot(*roc_V[algo][run], f'{colour}:', label=f'{algo} V')
        plt.legend(loc='lower left')
        plt.xlabel('sgn efficiency')
        plt.ylabel('bkg rejection')
        plt.title(f'ROC run {run}')
        plt.xlim((0, 1))
        plt.ylim((0, 1))
        plt.tight_layout()
        plt.savefig(plots_per_run / f'ROC_{run}.pdf')
        plt.close()

    df = pd.DataFrame(dd)
    df.to_pickle(output_dir / 'df.pkl')

    print('Making combined plots')

    for algo in vu.time_algorithms:
        for quantity, options in _SUMMARY_PLOTS:
            _save_summary_plot(df, quantity, algo, output_dir, total_length, **options)
303
304
if __name__ == '__main__':

    import argparse

    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # b2val-prompt-run also passes input_data_path and requested_iov to the script.
    # This validation does not need them: the positional argument takes one or more
    # values so that they are accepted, and only the first value is used.
    cli.add_argument(
        'calibration_results_dir',
        nargs='+',
        help='The directory that contains the collector outputs',
    )
    cli.add_argument(
        '-o', '--output_dir',
        default='SVDTimeValidation_output',
        help='The directory where all the output will be saved',
    )
    cli.add_argument(
        '-l',
        action='store_true',
        help='Make additional pdf with details cluster size vs shift',
    )
    options = cli.parse_args()

    results_dir = options.calibration_results_dir[0]
    run_validation(results_dir, output_dir=options.output_dir, shift_detailed=options.l)