Belle II Software  release-08-01-10
ChargedPidMVAModule.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # Doxygen should skip this script
12 # @cond
13 
14 """
15 This steering file fills an NTuple with the ChargedPidMVA score
16 for charged particle identification. By default, global PID info is stored,
17 meaning one signal hypothesis is tested against all others.
18 Optionally, binary PID can be stored, by testing one (or more) pair of (S,B) mass hypotheses.
19 
20 Usage:
21 
22 basf2 [-i /PATH/TO/MDST/FILE.root] analysis/examples/PostMdstIdentification/ChargedPidMVAModule.py -- [OPTIONS]
23 
24 Input: *_mdst_*.root
25 Output: *_ntup_*.root
26 
27 """
28 
29 import argparse
30 import re
31 from modularAnalysis import getAnalysisGlobaltag
32 
33 
34 def argparser():
35 
36  parser = argparse.ArgumentParser(description=__doc__,
37  formatter_class=argparse.RawTextHelpFormatter)
38 
39  def sb_pair(arg):
40  try:
41  s, b = map(int, arg.split(','))
42  return s, b
43  except BaseException:
44  raise argparse.ArgumentTypeError("Option string must be of the form 'S,B'")
45 
46  parser.add_argument("--matchTruth",
47  action="store_true",
48  default=False,
49  help="Apply truth-matching on particles.")
50  parser.add_argument("--testHyposPDGCodePair",
51  type=sb_pair,
52  nargs='+',
53  default=(0, 0),
54  help="Option required in binary mode.\n"
55  "A list of pdgId pairs of the (S, B) charged stable particle mass hypotheses to test.\n"
56  "Pass a space-separated list of (>= 1) S,B pdgIds, e.g.:\n"
57  "'--testHyposPDGCodePair 11,211 13,211'")
58  parser.add_argument("--addECLOnly",
59  dest="add_ecl_only",
60  action="store_true",
61  default=False,
62  help="Apply the BDT also for the ECL-only training."
63  "This will result in a separate score branch in the ntuple.")
64  parser.add_argument("--chargeIndependent",
65  action="store_true",
66  default=False,
67  help="Use a BDT trained on a sample of inclusively charged particles.")
68  parser.add_argument("--global_tag_append",
69  type=str,
70  nargs="+",
71  default=[getAnalysisGlobaltag()],
72  help="List of names of conditions DB global tag(s) to append on top of GT replay.\n"
73  "NB: these GTs will have lowest priority over GT replay.\n"
74  "The order of the sequence passed determines the priority of the GTs, w/ the highest coming first.\n"
75  "Pass a space-separated list of names.")
76  parser.add_argument("--global_tag_prepend",
77  type=str,
78  nargs="+",
79  default=None,
80  help="List of names of conditions DB global tag(s) to prepend to GT replay.\n"
81  "NB: these GTs will have highest priority over GT replay.\n"
82  "The order of the sequence passed determines the priority of the GTs, w/ the highest coming first.\n"
83  "Pass a space-separated list of names.")
84  parser.add_argument("--append_testing_payloads",
85  type=str,
86  default=None,
87  help="Path to a text file with local test payloads.\n"
88  "NB: these will have higher priority than any payload in the GT(s).\n"
89  "Use ONLY for testing.")
90  parser.add_argument("-d", "--debug",
91  dest="debug",
92  action="store",
93  default=0,
94  type=int,
95  choices=list(range(11, 20)),
96  help="Run the ChargedPidMVA module in debug mode. Pass the desired DEBUG level integer.")
97 
98  return parser
99 
100 
if __name__ == '__main__':

    args = argparser().parse_args()

    # basf2 / ROOT are imported only after argument parsing.
    import basf2
    import modularAnalysis as ma
    from ROOT import Belle2
    import pdg

    # ------------------------------
    # Conditions DB configuration.
    # ------------------------------

    # Per the --global_tag_append help: these GTs go below GT replay, and the
    # first-listed one has the highest priority, so append in the given order.
    for tag in args.global_tag_append:
        basf2.conditions.append_globaltag(tag)
    print(f"Appending GTs:\n{args.global_tag_append}")

    if args.global_tag_prepend:
        # Each prepend goes to the front, so iterate in reverse to leave the
        # first-listed GT with the highest priority.
        for tag in reversed(args.global_tag_prepend):
            basf2.conditions.prepend_globaltag(tag)
        print(f"Prepending GTs:\n{args.global_tag_prepend}")

    if args.append_testing_payloads:
        basf2.conditions.append_testing_payloads(args.append_testing_payloads)
        print(f"Appending testing payloads (will have highest priority!):\n{args.append_testing_payloads}")

    # ------------
    # Create path.
    # ------------

    path = basf2.create_path()

    # ----------
    # Add input.
    # ----------

    ma.inputMdst(filename=basf2.find_file("mdst14.root", "validation"),
                 path=path)

    # ---------------------------------------
    # Load standard charged stable particles,
    # and fill corresponding particle lists.
    # ---------------------------------------

    std_charged = [
        Belle2.Const.electron.getPDGCode(),
        Belle2.Const.muon.getPDGCode(),
        Belle2.Const.pion.getPDGCode(),
        Belle2.Const.kaon.getPDGCode(),
        Belle2.Const.proton.getPDGCode(),
        Belle2.Const.deuteron.getPDGCode(),
    ]

    # (decay string, cut) pairs, e.g. ("e-:my_e-", "") -- no selection cut.
    plists = [(f"{pdg.to_name(pdgId)}:my_{pdg.to_name(pdgId)}", "") for pdgId in std_charged]
    ma.fillParticleLists(plists, path=path)

    # Every module below operates on the same set of list names.
    plist_names = [plistname for plistname, _ in plists]

    # --------------------------
    # (Optional) truth matching.
    # --------------------------

    if args.matchTruth:
        for plistname in plist_names:
            ma.matchMCTruth(plistname, path=path)
            ma.applyCuts(plistname, "isSignal == 1", path=path)

    # -------------------
    # Global/Binary PID ?
    # -------------------

    # (0, 0) is the argparse default sentinel: no (S, B) pairs were requested.
    global_pid = (args.testHyposPDGCodePair == (0, 0))
    binary_pid = not global_pid

    # ----------------------
    # Apply charged Pid MVA.
    # ----------------------

    training_mode = Belle2.ChargedPidMVAWeights.ChargedPidMVATrainingMode

    # NOTE(review): the ECL-only calls below do not forward chargeIndependent,
    # as in the original script -- confirm this is intended.
    if global_pid:
        ma.applyChargedPidMVA(particleLists=plist_names,
                              path=path,
                              trainingMode=training_mode.c_Multiclass,
                              chargeIndependent=args.chargeIndependent)
        if args.add_ecl_only:
            ma.applyChargedPidMVA(particleLists=plist_names,
                                  path=path,
                                  trainingMode=training_mode.c_ECL_Multiclass)
    else:
        for s, b in args.testHyposPDGCodePair:
            ma.applyChargedPidMVA(particleLists=plist_names,
                                  path=path,
                                  trainingMode=training_mode.c_Classification,
                                  binaryHypoPDGCodes=(s, b),
                                  chargeIndependent=args.chargeIndependent)
            if args.add_ecl_only:
                ma.applyChargedPidMVA(particleLists=plist_names,
                                      path=path,
                                      trainingMode=training_mode.c_ECL_Classification,
                                      binaryHypoPDGCodes=(s, b))

    if args.debug:
        # Only raise verbosity for the ChargedPidMVA module(s) added above.
        for m in path.modules():
            if "ChargedPidMVA" in m.name():
                m.logging.log_level = basf2.LogLevel.DEBUG
                m.logging.debug_level = args.debug

    # ---------------
    # Make an NTuple.
    # ---------------

    if global_pid:
        append = "_vs_".join(map(str, std_charged))
        variables = [f"pidChargedBDTScore({pdgId}, ALL)" for pdgId in std_charged]
        if args.add_ecl_only:
            variables += [f"pidChargedBDTScore({pdgId}, ECL)" for pdgId in std_charged]
    else:
        append = "__".join(f"{s}_vs_{b}" for s, b in args.testHyposPDGCodePair)
        variables = [f"pidPairChargedBDTScore({s}, {b}, ALL)" for s, b in args.testHyposPDGCodePair]
        if args.add_ecl_only:
            variables += [f"pidPairChargedBDTScore({s}, {b}, ECL)" for s, b in args.testHyposPDGCodePair]

    filename = f"chargedpid_ntuples__{append}.root"

    # One tree per particle list, all written to the same output file.
    # The variable set is identical for every list, so a single call suffices.
    for plistname in plist_names:
        # ROOT doesn't like non-alphanum chars...
        treename = re.sub(r"[\W]+", "", plistname.split(':')[1])
        ma.variablesToNtuple(decayString=plistname,
                             variables=variables,
                             treename=treename,
                             filename=filename,
                             path=path)

    # -----------------
    # Monitor progress.
    # -----------------

    path.add_module("Progress")

    # ---------------
    # Process events.
    # ---------------

    # Start processing of modules.
    basf2.process(path)

    # Print basf2 call statistics.
    print(basf2.statistics)
257 
258 # @endcond