Belle II Software  release-08-01-10
__init__.py
1 
11 """
12 Generate PXD background samples for background overlay on the fly.
13 """
14 import functools
15 import hashlib
16 import os.path
17 import pathlib
18 from itertools import product
19 from typing import Callable, Union
20 
21 import basf2
22 from ROOT.Belle2 import DataStore, PyStoreArray, PyStoreObj
23 from ROOT.Belle2 import DBAccessorBase, DBStoreEntry
24 from ROOT.Belle2 import VxdID, PXDDigit
25 
26 from .models import MODELS
27 from .models import _get_model_cls, _get_generate_func
28 
29 
# Argument triples identifying every PXD module, in a fixed order:
# layer 1 contributes ladders 1..8 and layer 2 contributes ladders 1..12,
# each ladder with entries 1 and 2 in the last position.
# The position of a triple in this sequence is used as the batch index
# of the generated image for that module.
VXDID_ARGS = tuple(
    (layer, ladder, sensor)
    for layer, nladders in ((1, 8), (2, 12))
    for ladder in range(1, nladders + 1)
    for sensor in (1, 2)
)
40 
41 
42 
def _verify_model(model: str) -> str:
    """Validate the name of the generator model.

    :param model: Name of the generator model to use
    :type model: str

    :raises TypeError: if `model` is not a string
    :raises ValueError: if `model` is not contained in ``MODELS``

    :returns: The validated, unmodified model name
    """
    if not isinstance(model, str):
        raise TypeError("expecting type str `model`")
    if model not in MODELS:
        options = ", ".join(map(repr, MODELS))
        # list the valid choices; the message now closes its parenthesis
        raise ValueError(f"invalid `model`: {model!r} (options: {options})")
    return model
56 
57 
58 
def _verify_checkpoint(checkpoint: Union[None, str, pathlib.Path]) -> Union[None, str]:
    """Validate the path to a checkpoint file.

    :param checkpoint: Path to the checkpoint file, or None
    :type checkpoint: None, str or pathlib.Path

    :raises TypeError: if `checkpoint` is not None, a string or a path object
    :raises ValueError: if the path does not refer to an existing regular file

    :returns: None when None was given, otherwise the path as a string
        with environment variables expanded
    """
    if checkpoint is None:
        return None
    if not isinstance(checkpoint, (str, pathlib.Path)):
        raise TypeError("expecting None or type str or pathlib.Path `checkpoint`")
    # expand environment variables such as $HOME before testing the path
    checkpoint = os.path.expandvars(str(checkpoint))
    # isfile() is already False for non-existent paths
    if not os.path.isfile(checkpoint):
        raise ValueError(f"invalid `checkpoint`: {checkpoint!r}")
    return checkpoint
74 
75 
76 
def _verify_seed(seed: Union[None, int]) -> Union[None, int]:
    """Validate the initial seed.

    :param seed: Seed value, or None
    :type seed: None or int

    :raises TypeError: if `seed` is neither None nor an integer
    :raises ValueError: if `seed` lies outside the signed 64-bit range

    :returns: The validated, unmodified seed
    """
    if seed is None:
        return None
    if not isinstance(seed, int):
        raise TypeError("expecting None or type int `seed`")
    lower, upper = -(2 ** 63), 2 ** 63
    if seed < lower or seed >= upper:
        raise ValueError(f"expecting -2^63 <= `seed` < 2^63 (got: {seed})")
    return seed
91 
92 
93 
def _verify_nintra(nintra: int) -> int:
    """Validate the number of intra-op threads.

    :param nintra: Number of intra-op threads
    :type nintra: int

    :raises TypeError: if `nintra` is not an integer
    :raises ValueError: if `nintra` is not larger than zero

    :returns: The validated, unmodified number of threads
    """
    if not isinstance(nintra, int):
        raise TypeError("expecting type int `nintra`")
    if nintra <= 0:
        # the message now closes the parenthesis around the offending value
        raise ValueError(f"expecting `nintra` > 0 (got: {nintra})")
    return nintra
106 
107 
108 
def _verify_ninter(ninter: int) -> int:
    """Validate the number of inter-op threads.

    :param ninter: Number of inter-op threads
    :type ninter: int

    :raises TypeError: if `ninter` is not an integer
    :raises ValueError: if `ninter` is not larger than zero

    :returns: The validated, unmodified number of threads
    """
    if not isinstance(ninter, int):
        raise TypeError("expecting type int `ninter`")
    if ninter <= 0:
        # the message now closes the parenthesis around the offending value
        raise ValueError(f"expecting `ninter` > 0 (got: {ninter})")
    return ninter
121 
122 
123 
def _verify_globaltag(globaltag: str) -> str:
    """Validate the name of the global tag.

    :param globaltag: Name of the global tag
    :type globaltag: str

    :raises TypeError: if `globaltag` is not a string

    :returns: The validated, unmodified global tag name
    """
    if isinstance(globaltag, str):
        return globaltag
    raise TypeError("expecting type str `globaltag`")
133 
134 
135 
class PXDBackgroundGenerator(basf2.Module):
    """Generates PXD background samples for background overlay on the fly.

    :param model: Name of the generator model to use - either "convnet" or "resnet",
        defaults to "convnet" (optional)
    :type model: str

    :param checkpoint: Path to the checkpoint file with weights for the selected model,
        defaults to None - use the default checkpoint from the conditions database (optional)
    :type checkpoint: str

    :param seed: Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]`
        used internally as the initial seed,
        defaults to None - derive a deterministic seed from the
        value returned by :py:func:`basf2.get_random_seed` (optional)
    :type seed: int

    :param nintra: Number of intra-op threads to be utilized for the generation,
        defaults to 1 (optional)
    :type nintra: int

    :param ninter: Number of inter-op threads to be utilized for the generation,
        defaults to 1 (optional)
    :type ninter: int

    :param globaltag: Global tag of the conditions database
        providing the default checkpoints stored as payloads,
        defaults to "PXDBackgroundGenerator" (optional)
    :type globaltag: str
    """

    def __init__(
        self,
        model: str = "convnet",
        checkpoint: Union[None, str, pathlib.Path] = None,
        seed: Union[None, int] = None,
        nintra: int = 1,
        ninter: int = 1,
        globaltag: str = "PXDBackgroundGenerator",
    ):
        """Validate and store the module parameters and enable the global tag.

        Each argument is passed through its ``_verify_*`` function, which raises
        TypeError or ValueError for invalid values.
        """
        super().__init__()

        # Name of the generator model
        self.model = _verify_model(model)

        # Path to the checkpoint file with the model weights, or None
        self.checkpoint = _verify_checkpoint(checkpoint)

        # Initial seed, or None to derive it from the basf2 seed in initialize()
        self.seed = _verify_seed(seed)

        # Number of intra-op threads utilized
        self.nintra = _verify_nintra(nintra)

        # Number of inter-op threads utilized
        self.ninter = _verify_ninter(ninter)

        # Global tag of the conditions database providing the default checkpoints
        self.globaltag = _verify_globaltag(globaltag)

        # enable the specified global tag
        basf2.conditions.append_globaltag(self.globaltag)

    def initialize(self):
        """Set up the generator model, the seed and the output digit collection.

        Called before event processing. Imports PyTorch lazily, configures
        the thread counts, loads the model weights, seeds the generator and
        registers the background digit array in the data store.

        :raises ImportError: if PyTorch is not installed
        """
        try:
            import torch
        except ImportError as exc:
            exc.msg = "please install PyTorch: `pip3 install torch==1.4.0`"
            raise

        # set the number of inter-op CPU threads
        torch.set_num_interop_threads(self.ninter)

        # set the number of intra-op CPU threads
        torch.set_num_threads(self.nintra)

        # Generation function applied on the model instance
        # to return an output that is transcoded into digits
        self._generate_func = _get_generate_func(self.model)

        # Generator model instance
        self._generator = _get_model_cls(self.model)()

        # initialize the model weights
        checkpoint = self.checkpoint
        if checkpoint is None:
            # use the default checkpoint from the conditions database
            payload = f"PXDBackgroundGenerator_{self.model}"
            accessor = DBAccessorBase(DBStoreEntry.c_RawFile, payload, True)
            checkpoint = accessor.getFilename()
        # NOTE(review): torch.load unpickles the file - only checkpoints
        # from trusted sources should be loaded here
        self._generator.load_state_dict(torch.load(checkpoint, map_location="cpu"))

        # set mode of operation to inference
        self._generator.eval()

        # disable the computation of gradients
        for param in self._generator.parameters():
            param.requires_grad = False

        # initialize the seed
        seed = self.seed
        if seed is None:
            # derive from the basf2 seed value: hash its string representation
            # and interpret the first 8 bytes as a signed 64-bit integer
            obj = basf2.get_random_seed()
            func = hashlib.sha512()
            func.update(bytes(str(obj), "utf-8"))
            digest = func.digest()[:8]
            seed = int.from_bytes(digest, "big", signed=True)
        basf2.B2INFO(f"PXD background generator seed initialized to {seed}")
        torch.manual_seed(seed)

        # Sequence of identifier objects for each PXD module
        # NOTE(review): each element of VXDID_ARGS is passed to VxdID as a
        # single argument - confirm that this matches a VxdID constructor
        self._vxdids = tuple(VxdID(arg) for arg in VXDID_ARGS)

        # get the name of the extension used by BGOverlayInput for background collections
        bkginfo = PyStoreObj("BackgroundInfo", DataStore.c_Persistent)
        if not bkginfo.isValid():
            # no information about background overlay input available
            basf2.B2ERROR("path must contain the BGOverlayInput module")
            # do not query the invalid accessor below
            # NOTE(review): presumably basf2 stops before event processing
            # once an error was logged during initialization - confirm
            return
        extension = bkginfo.getExtensionName()

        # Accessor for PXD background digits in the data store
        self._pystorearray = PyStoreArray("PXDDigits", DataStore.c_DontWriteOut)
        self._pystorearray.registerInDataStore(f"PXDDigits{extension}")

    def event(self):
        """Replace the background digits with a freshly generated sample.

        Called each time an event is processed.
        """
        # get the low-level array accessor
        digit_array = self._pystorearray.getPtr()

        # clear digits stored by BGOverlayInput
        digit_array.Clear()

        # generate a batch of images - one image for each PXD module
        x = self._generate_func(self._generator)

        # locate indices of pixels with non-zero values - pixel hits
        nonzero = x.nonzero(as_tuple=True)
        args = [indices.tolist() for indices in nonzero]
        vals = x[nonzero].tolist()

        # store indices and pixel values into the data store as background digits
        for n, (idx, ucell, vcell, charge) in enumerate(zip(*args, vals)):
            # append a new default digit to expand the array
            self._pystorearray.appendNew()
            # modify the array to point to the correct digit
            digit_array[n] = PXDDigit(self._vxdids[idx], ucell, vcell, charge)

        # delete references to release memory
        del x, nonzero, args, vals
311 
312 
315 
316 
319 
320 
324 
325 
328 
329 
332 
333 
337 
338 
341 
342 
346 
347 
350 
351 
354 
355 
356 
def inject_simulation(module: Union[None, basf2.Module] = None) -> Callable:
    """Incorporate a module instance
    into :py:func:`.add_simulation` after `!BGOverlayInput`.

    :param module: Module instance to be incorporated,
        defaults to None - return unmodified function
    :type module: :py:class:`basf2.Module`, optional

    :raises TypeError: if `module` is neither None nor a :py:class:`basf2.Module`

    :returns: Drop-in replacement function for :py:func:`.add_simulation`
    """
    from simulation import add_simulation

    # nothing to inject - hand back the original function
    if module is None:
        return add_simulation

    if not isinstance(module, basf2.Module):
        raise TypeError("expecting None or type basf2.Module `module`")

    @functools.wraps(add_simulation)
    def injected_simulation(path, *args, **kwargs):
        # build the simulation on a scratch path first
        staging_path = basf2.Path()
        add_simulation(staging_path, *args, **kwargs)

        # copy the modules over one by one, adding the given module
        # directly behind every module named BGOverlayInput
        for staged in staging_path.modules():
            path.add_module(staged)
            if staged.name() != "BGOverlayInput":
                continue
            path.add_module(module)

    return injected_simulation
Class for the PXD background generator module.
Definition: __init__.py:137
_generate_func
Generation function applied on the model instance to return an output that is transcoded into digits.
Definition: __init__.py:238
ninter
Number of inter-op threads utilized.
Definition: __init__.py:214
def __init__(self, str model="convnet", Union[None, str, pathlib.Path] checkpoint=None, Union[None, int] seed=None, int nintra=1, int ninter=1, str globaltag="PXDBackgroundGenerator")
Constructor for the PXD background generator module.
Definition: __init__.py:199
nintra
Number of intra-op threads utilized.
Definition: __init__.py:211
checkpoint
Path to the checkpoint file with the pre-trained model weights.
Definition: __init__.py:205
globaltag
Global tag of the conditions database providing the default checkpoints stored as payloads.
Definition: __init__.py:217
seed
Integer number in the interval [-2^63, 2^63 - 1] set as the initial seed.
Definition: __init__.py:208
_vxdids
Sequence of identifier objects for each PXD module.
Definition: __init__.py:272
def initialize(self)
Method called before event processing to initialize the module.
Definition: __init__.py:224
_pystorearray
Accessor for PXD background digits in the data store.
Definition: __init__.py:282
def event(self)
Method called each time an event is processed.
Definition: __init__.py:287
double eval(const std::vector< double > &spl, const std::vector< double > &vals, double x)
Evaluate spline (zero order or first order) in point x.
Definition: tools.h:115
int _verify_nintra(int nintra)
Function to verify that nintra is an integer larger than zero.
Definition: __init__.py:100
str _verify_model(str model)
Function to verify that model is a string and one of the available models.
Definition: __init__.py:49
Callable inject_simulation(Union[None, basf2.Module] module=None)
Helper function to incorporate a module instance into add_simulation after BGOverlayInput.
Definition: __init__.py:364
int _verify_ninter(int ninter)
Function to verify that ninter is an integer larger than zero.
Definition: __init__.py:115
Union[None, int] _verify_seed(Union[None, int] seed)
Function to verify that seed is None or an integer in the interval [-2^63, 2^63 - 1].
Definition: __init__.py:83
str _verify_checkpoint(Union[None, str, pathlib.Path] checkpoint)
Function to verify that checkpoint is None or a path (str or pathlib.Path) to an existing file.
Definition: __init__.py:65
str _verify_globaltag(str globaltag)
Function to verify that globaltag is a string.
Definition: __init__.py:129