4 from pathlib
import Path
10 import matplotlib.pyplot
as plt
11 import matplotlib.ticker
as ticker
13 from prompt
import ValidationSettings
# Set PyConfig.IgnoreCommandLineOptions on `r` so this script's own
# command-line options are left alone.
# NOTE(review): `r` is presumably the ROOT module; its import is not visible in this excerpt.
17 r.PyConfig.IgnoreCommandLineOptions =
True
# Use the matplotlib style named "belle2" for the figures made by this module.
20 plt.style.use(
"belle2")
# Validation settings created with name "caf_svd_time"; the remaining
# constructor arguments are not visible in this excerpt.
23 settings = ValidationSettings(name=
"caf_svd_time",
def progress(count, total, bar_len=60):
    """Draw a one-line text progress bar on standard output.

    The bar is written with a trailing carriage return instead of a newline,
    so each call overwrites the previous one on the same terminal line.

    Args:
        count: Number of items processed so far.
        total: Total number of items. ``0`` is rendered as a complete bar
            instead of raising ``ZeroDivisionError`` (the caller reports
            ``progress(0, num_files)`` even when no input file was found).
        bar_len: Width of the bar in characters.
            NOTE(review): the original definition of ``bar_len`` is not
            visible in this excerpt; 60 is an assumed default -- confirm.
    """
    if total:
        filled_len = int(round(bar_len * count / total))
        percents = round(100 * count / total, 1)
    else:
        # Nothing to process: show the bar as complete rather than dividing by zero.
        filled_len = bar_len
        percents = 100.0
    bar = '=' * filled_len + '-' * (bar_len - filled_len)
    sys.stdout.write(f'[{bar}] {percents}%\r')
    # No newline is written, so flush explicitly; a line-buffered terminal
    # would otherwise not display the update.
    sys.stdout.flush()
# run_validation: read the collector outputs found under `job_path`, fill
# per-algorithm / per-run summary dictionaries, save per-run and combined
# plots, and pickle a summary DataFrame into the output directory.
# NOTE(review): several statements of this function are not visible in this
# excerpt (for instance the `try:` that pairs with `except AttributeError:`),
# so the comments below describe only what the visible statements do.
# NOTE(review): `input_data_path` is not referenced by any visible statement.
38 def run_validation(job_path, input_data_path=None, **kwargs):
39 '''job_path will be replaced with path/to/calibration_results
40 input_data_path will be replaced with path/to/data_path used for calibration
41 e.g. /group/belle2/dataprod/Data/PromptSkim/'''
# Input directory is fixed relative to job_path; the output directory comes
# from kwargs['output_dir'] and defaults to 'SVDTimeValidation_output'.
43 collector_output_dir = Path(job_path) /
'SVDTimeValidation/0/collector_output/default/'
44 output_dir = Path(kwargs.get(
'output_dir',
'SVDTimeValidation_output'))
45 plots_per_run = output_dir /
'runs'
# Creating <output_dir>/runs with parents=True also creates output_dir itself.
47 plots_per_run.mkdir(parents=
True, exist_ok=
True)
# Every CollectorOutput.root found recursively below the collector output directory.
49 files = list(collector_output_dir.glob(
'**/CollectorOutput.root'))
# One empty dict per entry of vu.time_algorithms; each is filled per run below.
51 agreements = {algo: {}
for algo
in vu.time_algorithms}
52 precisions = {algo: {}
for algo
in vu.time_algorithms}
53 discriminations = {algo: {}
for algo
in vu.time_algorithms}
54 entries_onTracks = {algo: {}
for algo
in vu.time_algorithms}
55 entries_eventT0 = {algo: {}
for algo
in vu.time_algorithms}
57 roc_U = {algo: {}
for algo
in vu.time_algorithms}
58 roc_V = {algo: {}
for algo
in vu.time_algorithms}
# num_files may be 0 when the glob matches nothing; it is passed to progress()
# as the total.
60 num_files = len(files)
61 print(f
'Looping over {num_files} files')
62 progress(0, num_files)
63 for count, in_file_name
in enumerate(files):
# NOTE(review): no statement closing this TFile is visible in this excerpt.
65 in_file = r.TFile(str(in_file_name))
67 for algo
in vu.time_algorithms:
69 histos, exp, run = vu.get_histos(in_file, algo)
# NOTE(review): the condition guarding this first "Skipping file" message is
# not visible in this excerpt.
72 print(f
'Skipping file {in_file_name} for {algo}')
# A run is (re)filled only when it has not been seen yet for this algorithm or
# when this file has more eventT0 entries than the one already stored, so each
# run ends up described by its file with the most eventT0 entries.
77 entries_eventT0_ = histos[
'eventT0'].GetEntries()
78 if run
not in entries_eventT0[algo]
or entries_eventT0_ > entries_eventT0[algo][run]:
# Per-key summaries, keyed like histos['diff'] (agreement, precision) or
# histos['onTracks'] (discrimination, on-track entries).
79 agreements[algo][run] = {key: vu.get_agreament(histos[
'eventT0'], h_diff)
80 for key, h_diff
in histos[
'diff'].items()}
81 precisions[algo][run] = {key: vu.get_precision(h_diff)
82 for key, h_diff
in histos[
'diff'].items()}
83 discriminations[algo][run] = {key: vu.get_roc_auc(histos[
'onTracks'][key], histos[
'offTracks'][key])
84 for key
in histos[
'onTracks']}
85 entries_onTracks[algo][run] = {key: val.GetEntries()
for key, val
in histos[
'onTracks'].items()}
86 entries_eventT0[algo][run] = entries_eventT0_
# Combined plots for the '*U' and '*V' selections, saved in the per-run
# directory as <exp>_<run>_<U|V>_<algo>.pdf.
88 vu.make_combined_plot(
'*U', histos,
89 title=f
'exp {exp} run {run} U {algo}')
90 plt.savefig(plots_per_run / f
'{exp}_{run}_U_{algo}.pdf')
93 vu.make_combined_plot(
'*V', histos,
94 title=f
'exp {exp} run {run} V {algo}')
95 plt.savefig(plots_per_run / f
'{exp}_{run}_V_{algo}.pdf')
# Store the output of vu.make_roc for the combined on-track / off-track
# histograms of each selection; it is unpacked into plt.plot() further down.
98 roc_U[algo][run] = vu.make_roc(vu.get_combined(histos[
'onTracks'],
'*U'),
99 vu.get_combined(histos[
'offTracks'],
'*U'))
100 roc_V[algo][run] = vu.make_roc(vu.get_combined(histos[
'onTracks'],
'*V'),
101 vu.get_combined(histos[
'offTracks'],
'*V'))
# An AttributeError is reported as a skipped file for this algorithm.
# NOTE(review): the matching `try:` and whatever follows the print are not
# visible in this excerpt.
102 except AttributeError:
103 print(f
'Skipping file {in_file_name} for {algo}')
109 progress(count+1, num_files)
# Flatten the per-run dictionaries into table columns in `dd`: each run is
# repeated once per entry of vu.names_sides, 'name' cycles through
# vu.names_sides, and 'side' is the last character of each name.
# NOTE(review): the statement creating `dd` is not visible in this excerpt.
# NOTE(review): `runs` is taken from the first algorithm only; the lookups in
# the loop below raise KeyError if another algorithm lacks one of these runs.
114 runs = sorted(agreements[vu.time_algorithms[0]])
115 dd[
'run'] = sum([[i]*len(vu.names_sides)
for i
in runs], [])
116 dd[
'name'] = vu.names_sides*len(runs)
117 dd[
'side'] = [i[-1]
for i
in dd[
'name']]
119 for algo
in vu.time_algorithms:
120 dd[f
'agreement_{algo}'] = [agreements[algo][run][side]
for run, side
in zip(dd[
'run'], dd[
'name'])]
121 dd[f
'precision_{algo}'] = [precisions[algo][run][side]
for run, side
in zip(dd[
'run'], dd[
'name'])]
122 dd[f
'discrimination_{algo}'] = [discriminations[algo][run][side]
for run, side
in zip(dd[
'run'], dd[
'name'])]
123 dd[f
'entries_onTracks_{algo}'] = [entries_onTracks[algo][run][side]
for run, side
in zip(dd[
'run'], dd[
'name'])]
# entries_eventT0 is stored per run only, so the same value is repeated for
# every name of that run (`side` is unused in this comprehension).
124 dd[f
'entries_eventT0_{algo}'] = [entries_eventT0[algo][run]
for run, side
in zip(dd[
'run'], dd[
'name'])]
# ROC overlay for the three hard-coded algorithm names 'CoG6', 'CoG3' and
# 'ELS3': U with a solid line, V with a dotted line, one colour per algorithm.
# NOTE(review): the loop header binding `run` for this section is not visible
# in this excerpt.
129 plt.plot(*roc_U[
'CoG6'][run],
'k-', label=
'CoG6 U')
130 plt.plot(*roc_V[
'CoG6'][run],
'k:', label=
'CoG6 V')
131 plt.plot(*roc_U[
'CoG3'][run],
'b-', label=
'CoG3 U')
132 plt.plot(*roc_V[
'CoG3'][run],
'b:', label=
'CoG3 V')
133 plt.plot(*roc_U[
'ELS3'][run],
'r-', label=
'ELS3 U')
134 plt.plot(*roc_V[
'ELS3'][run],
'r:', label=
'ELS3 V')
135 plt.legend(loc=
'lower left')
136 plt.xlabel(
'sgn efficiency')
137 plt.ylabel(
'bkg rejection')
138 plt.title(f
'ROC run {run}')
142 plt.savefig(plots_per_run / f
'ROC_{run}.pdf')
# Build the summary table from `dd` and pickle it as <output_dir>/df.pkl.
145 df = pd.DataFrame(dd)
146 df.to_pickle(output_dir /
'df.pkl')
150 print(
'Making combined plots')
# For every algorithm, draw one split violin plot (hue = side) per quantity
# versus run and save it as <output_dir>/<quantity>_<algo>.pdf. The figure is
# twice the default height (4.8*2) and its width scales with num_files/30,
# never below twice the default width of 6.4.
152 for algo
in vu.time_algorithms:
# Agreement: dashed reference line at 0, dotted lines at +0.5 and -0.5.
153 plt.figure(figsize=(6.4*max(2, num_files/30), 4.8*2))
154 ax = sns.violinplot(x=
'run', y=f
'agreement_{algo}', hue=
'side', data=df, split=
True)
156 ax.xaxis.set_minor_locator(ticker.NullLocator())
157 plt.axhline(0, color=
'black', linestyle=
'--')
158 plt.axhline(0.5, color=
'black', linestyle=
':')
159 plt.axhline(-0.5, color=
'black', linestyle=
':')
160 plt.setp(ax.get_xticklabels(), rotation=90)
162 plt.savefig(output_dir / f
'agreement_{algo}.pdf')
# Precision: dotted reference lines at 10 and 20.
165 plt.figure(figsize=(6.4*max(2, num_files/30), 4.8*2))
166 ax = sns.violinplot(x=
'run', y=f
'precision_{algo}', hue=
'side', data=df, split=
True)
168 ax.xaxis.set_minor_locator(ticker.NullLocator())
169 plt.axhline(10, color=
'black', linestyle=
':')
170 plt.axhline(20, color=
'black', linestyle=
':')
171 plt.setp(ax.get_xticklabels(), rotation=90)
173 plt.savefig(output_dir / f
'precision_{algo}.pdf')
# Discrimination: y axis limited to [0.5, 1], dotted reference lines at 0.8 and 0.9.
176 plt.figure(figsize=(6.4*max(2, num_files/30), 4.8*2))
177 ax = sns.violinplot(x=
'run', y=f
'discrimination_{algo}', hue=
'side', data=df, split=
True)
178 ax.set_ylim([0.5, 1])
179 ax.xaxis.set_minor_locator(ticker.NullLocator())
180 plt.axhline(0.8, color=
'black', linestyle=
':')
181 plt.axhline(0.9, color=
'black', linestyle=
':')
182 plt.setp(ax.get_xticklabels(), rotation=90)
184 plt.savefig(output_dir / f
'discrimination_{algo}.pdf')
# On-track entries: the only violin plot here drawn with cut=0.
187 plt.figure(figsize=(6.4*max(2, num_files/30), 4.8*2))
188 ax = sns.violinplot(x=
'run', y=f
'entries_onTracks_{algo}', hue=
'side', data=df, split=
True, cut=0)
189 ax.xaxis.set_minor_locator(ticker.NullLocator())
190 plt.setp(ax.get_xticklabels(), rotation=90)
192 plt.savefig(output_dir / f
'entries_onTracks_{algo}.pdf')
# eventT0 entries (the same value for every name of a run, see above).
195 plt.figure(figsize=(6.4*max(2, num_files/30), 4.8*2))
196 ax = sns.violinplot(x=
'run', y=f
'entries_eventT0_{algo}', hue=
'side', data=df, split=
True)
197 ax.xaxis.set_minor_locator(ticker.NullLocator())
198 plt.setp(ax.get_xticklabels(), rotation=90)
200 plt.savefig(output_dir / f
'entries_eventT0_{algo}.pdf')
# Script entry point: parse the calibration results directory (positional) and
# the output directory (-o/--output_dir, default 'SVDTimeValidation_output'),
# then run the validation. The parser uses the module docstring as description.
204 if __name__ ==
'__main__':
207 parser = argparse.ArgumentParser(description=__doc__,
208 formatter_class=argparse.RawTextHelpFormatter)
213 parser.add_argument(
'calibration_results_dir',
214 help=
'The directory that contains the collector outputs',
217 parser.add_argument(
'-o',
'--output_dir',
218 help=
'The directory where all the output will be saved',
219 default=
'SVDTimeValidation_output')
220 args = parser.parse_args()
# NOTE(review): calibration_results_dir is indexed with [0]. That yields the
# directory only if the argument is declared with `nargs` (the end of that
# add_argument call is not visible in this excerpt); on a plain string it
# would yield the first character of the path -- confirm.
222 run_validation(args.calibration_results_dir[0], output_dir=args.output_dir)