Belle II Software  release-06-01-15
__init__.py
1 
11 """
12 Generate PXD background samples for background overlay on the fly.
13 """
14 import functools
15 import hashlib
16 import os.path
17 import pathlib
18 from itertools import product
19 from secrets import randbelow
20 from typing import Callable, Union
21 
22 import basf2
23 from ROOT.Belle2 import DataStore, PyStoreArray, PyStoreObj
24 from ROOT.Belle2 import DBAccessorBase, DBStoreEntry
25 from ROOT.Belle2 import VxdID, PXDDigit
26 
27 from .models import MODELS
28 from .models import _get_model_cls, _get_generate_func
29 
30 
# Constructor-argument tuples for every PXD module, in a fixed order.
# Each entry is presumably (layer, ladder, sensor), matching the VxdID
# arguments (TODO confirm). Layer 1 has 8 ladders and layer 2 has 12 ladders,
# and each ladder has sensors 1 and 2, giving 40 entries in total.
VXDID_ARGS = tuple(
    (layer, ladder, sensor)
    for layer, nladders in ((1, 8), (2, 12))
    for ladder in range(1, nladders + 1)
    for sensor in (1, 2)
)
41 
42 
43 
50 def _verify_model(model: str) -> str:
51  if not isinstance(model, str):
52  raise TypeError("expecting type str `model`")
53  elif model not in MODELS:
54  options = ", ".join(map(repr, MODELS))
55  raise ValueError(f"invalid `model`: {model!r} (options: {options}")
56  return model
57 
58 
59 
66 def _verify_checkpoint(checkpoint: Union[None, str, pathlib.Path]) -> str:
67  if not isinstance(checkpoint, (type(None), str, pathlib.Path)):
68  raise TypeError("expecting None or type str `checkpoint`")
69  if checkpoint is None:
70  return checkpoint
71  checkpoint = os.path.expandvars(str(checkpoint))
72  if not (os.path.exists(checkpoint) and os.path.isfile(checkpoint)):
73  raise ValueError(f"invalid `checkpoint`: {checkpoint!r}")
74  return checkpoint
75 
76 
77 
84 def _verify_seed(seed: Union[None, int]) -> Union[None, int]:
85  if not isinstance(seed, (type(None), int)):
86  raise TypeError("expecting None or type int `seed`")
87  if seed is None:
88  return seed
89  if not -(2 ** 63) <= seed < 2 ** 63:
90  raise ValueError(f"expecting -2^63 <= `seed` < 2^63 (got: {seed})")
91  return seed
92 
93 
94 
101 def _verify_nintra(nintra: int) -> int:
102  if not isinstance(nintra, int):
103  raise TypeError("expecting type int `nintra`")
104  elif not nintra > 0:
105  raise ValueError(f"expecting `nintra` > 0 (got: {nintra}")
106  return nintra
107 
108 
109 
116 def _verify_ninter(ninter: int) -> int:
117  if not isinstance(ninter, int):
118  raise TypeError("expecting type int `ninter`")
119  elif not ninter > 0:
120  raise ValueError(f"expecting `ninter` > 0 (got: {ninter}")
121  return ninter
122 
123 
124 
130 def _verify_globaltag(globaltag: str) -> str:
131  if not isinstance(globaltag, str):
132  raise TypeError("expecting type str `globaltag`")
133  return globaltag
134 
135 
136 
class PXDBackgroundGenerator(basf2.Module):
    """Generates PXD background samples for background overlay on the fly.

    :param model: Name of the generator model to use - either "convnet" or "resnet",
        defaults to "convnet"
    :type model: str, optional

    :param checkpoint: Path to the checkpoint file with weights for the selected model,
        defaults to None - use the default checkpoint from the conditions database
    :type checkpoint: str, optional

    :param seed: Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]`
        used internally as the initial seed,
        defaults to None - derive a deterministic seed from the
        value returned by :py:func:`basf2.get_random_seed`
    :type seed: int, optional

    :param nintra: Number of intra-op threads to be utilized for the generation,
        defaults to 1
    :type nintra: int, optional

    :param ninter: Number of inter-op threads to be utilized for the generation,
        defaults to 1
    :type ninter: int, optional

    :param globaltag: Global tag of the conditions database
        providing the default checkpoints stored as payloads,
        defaults to "PXDBackgroundGenerator"
    :type globaltag: str, optional
    """

    def __init__(
        self,
        model: str = "convnet",
        checkpoint: Union[None, str, pathlib.Path] = None,
        seed: Union[None, int] = None,
        nintra: int = 1,
        ninter: int = 1,
        globaltag: str = "PXDBackgroundGenerator",
    ):
        """Validate and store the module parameters, then enable the global tag.

        :raises TypeError: if a parameter has the wrong type
        :raises ValueError: if a parameter has an invalid value
        """
        super().__init__()

        #: Name of the generator model to use.
        self.model = _verify_model(model)

        #: Path to the checkpoint file with the pre-trained model weights
        #: (None: use the default payload from the conditions database).
        self.checkpoint = _verify_checkpoint(checkpoint)

        #: Integer number in the interval [-2^63, 2^63 - 1] set as the initial
        #: seed (None: derive it from the basf2 random seed).
        self.seed = _verify_seed(seed)

        #: Number of intra-op threads utilized.
        self.nintra = _verify_nintra(nintra)

        #: Number of inter-op threads utilized.
        self.ninter = _verify_ninter(ninter)

        #: Global tag of the conditions database providing the default
        #: checkpoints stored as payloads.
        self.globaltag = _verify_globaltag(globaltag)

        # enable the specified global tag so the default payloads can be found
        basf2.conditions.append_globaltag(self.globaltag)

    def initialize(self):
        """Set up PyTorch, load the generator model and register the output array.

        Called by basf2 once before event processing starts.

        :raises ImportError: if PyTorch is not installed
        """
        try:
            import torch
        except ImportError as exc:
            exc.msg = "please install PyTorch: `pip3 install torch==1.4.0`"
            raise

        # configure the CPU thread pools used by PyTorch
        torch.set_num_interop_threads(self.ninter)
        torch.set_num_threads(self.nintra)

        #: Generation function applied on the model instance to return an
        #: output that is transcoded into digits.
        self._generate_func = _get_generate_func(self.model)

        #: Generator model instance.
        self._generator = _get_model_cls(self.model)()

        # initialize the model weights
        checkpoint = self.checkpoint
        if checkpoint is None:
            # use the default checkpoint from the conditions database
            payload = f"PXDBackgroundGenerator_{self.model}"
            accessor = DBAccessorBase(DBStoreEntry.c_RawFile, payload, True)
            checkpoint = accessor.getFilename()
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints
        self._generator.load_state_dict(torch.load(checkpoint, map_location="cpu"))

        # inference only: switch to evaluation mode and disable gradients
        self._generator.eval()
        for param in self._generator.parameters():
            param.requires_grad = False

        # initialize the seed
        seed = self.seed
        if seed is None:
            # derive deterministically from the basf2 seed: hash it with
            # SHA-512 and interpret the first 8 bytes as a signed 64-bit integer
            obj = basf2.get_random_seed()
            func = hashlib.sha512()
            func.update(bytes(str(obj), "utf-8"))
            digest = func.digest()[:8]
            seed = int.from_bytes(digest, "big", signed=True)
            basf2.B2INFO(f"PXD background generator seed initialized to {seed}")
        torch.manual_seed(seed)

        #: Sequence of identifier objects for each PXD module; the first
        #: index of a generated pixel hit selects an entry of this tuple.
        self._vxdids = tuple(VxdID(arg) for arg in VXDID_ARGS)

        # get the name of the extension used by BGOverlayInput for background collections
        bkginfo = PyStoreObj("BackgroundInfo", DataStore.c_Persistent)
        if not bkginfo.isValid():
            # Abort instead of only logging: the getExtensionName() call below
            # would otherwise be made on an invalid object.
            basf2.B2FATAL("path must contain the BGOverlayInput module")
        extension = bkginfo.getExtensionName()

        #: Accessor for PXD background digits in the data store.
        self._pystorearray = PyStoreArray("PXDDigits", DataStore.c_DontWriteOut)
        self._pystorearray.registerInDataStore(f"PXDDigits{extension}")

    def event(self):
        """Replace the overlay PXD digits with freshly generated ones.

        Called by basf2 once per event. The digits stored by BGOverlayInput are
        cleared, and one digit is appended for every non-zero pixel of the
        generated batch.
        """
        # get the low-level array accessor
        digit_array = self._pystorearray.getPtr()

        # clear digits stored by BGOverlayInput
        digit_array.Clear()

        # Generate a batch of images, one image for each PXD module. Each hit is
        # unpacked into three indices (module, ucell, vcell) below, so x is
        # 3-dimensional.
        x = self._generate_func(self._generator)

        # locate indices of pixels with non-zero values - pixel hits
        nonzero = x.nonzero(as_tuple=True)
        args = [indices.tolist() for indices in nonzero]
        vals = x[nonzero].tolist()

        # store indices and pixel values into the data store as background digits
        for n, (idx, ucell, vcell, charge) in enumerate(zip(*args, vals)):
            # append a new default digit to expand the array ...
            self._pystorearray.appendNew()
            # ... then overwrite it with the generated digit
            digit_array[n] = PXDDigit(self._vxdids[idx], ucell, vcell, charge)

        # delete references to release memory
        del x, nonzero, args, vals
312 
313 
316 
317 
320 
321 
325 
326 
329 
330 
333 
334 
338 
339 
342 
343 
347 
348 
351 
352 
355 
356 
357 
def inject_simulation(module: Union[None, basf2.Module] = None) -> Callable:
    """Incorporate a module instance
    into :py:func:`.add_simulation` after `!BGOverlayInput`.

    :param module: Module instance to be incorporated,
        defaults to None - return unmodified function
    :type module: :py:class:`basf2.Module`, optional

    :raises TypeError: if `module` is neither None nor a :py:class:`basf2.Module`

    :returns: Drop-in replacement function for :py:func:`.add_simulation`
    """
    from simulation import add_simulation

    if module is None:
        # nothing to inject - hand back the original function
        return add_simulation
    if not isinstance(module, basf2.Module):
        raise TypeError("expecting None or type basf2.Module `module`")

    @functools.wraps(add_simulation)
    def injected_simulation(path, *args, **kwargs):
        # build the simulation in a scratch path first ...
        staging = basf2.Path()
        add_simulation(staging, *args, **kwargs)

        # ... then copy its modules over one by one, inserting the given
        # module right after every module named "BGOverlayInput"
        for sim_module in staging.modules():
            path.add_module(sim_module)
            if sim_module.name() == "BGOverlayInput":
                path.add_module(module)

    return injected_simulation
Class for the PXD background generator module.
Definition: __init__.py:138
_generate_func
Generation function applied on the model instance to return an output that is transcoded into digits.
Definition: __init__.py:239
ninter
Number of inter-op threads utilized.
Definition: __init__.py:215
def __init__(self, str model="convnet", Union[None, str, pathlib.Path] checkpoint=None, Union[None, int] seed=None, int nintra=1, int ninter=1, str globaltag="PXDBackgroundGenerator")
Constructor for the PXD background generator module.
Definition: __init__.py:200
nintra
Number of intra-op threads utilized.
Definition: __init__.py:212
checkpoint
Path to the checkpoint file with the pre-trained model weights.
Definition: __init__.py:206
globaltag
Global tag of the conditions database providing the default checkpoints stored as payloads.
Definition: __init__.py:218
seed
Integer number in the interval [-2^63, 2^63 - 1] set as the initial seed.
Definition: __init__.py:209
_vxdids
Sequence of identifier objects for each PXD module.
Definition: __init__.py:273
def initialize(self)
Method called before event processing to initialize the module.
Definition: __init__.py:225
_pystorearray
Accessor for PXD background digits in the data store.
Definition: __init__.py:283
def event(self)
Method called each time an event is processed.
Definition: __init__.py:288
double eval(const std::vector< double > &spl, const std::vector< double > &vals, double x)
Evaluate spline (zero order or first order) in point x.
Definition: tools.h:115
int _verify_nintra(int nintra)
Function to verify that nintra is of type int and greater than zero.
Definition: __init__.py:101
str _verify_model(str model)
Function to verify that model is of type str and is one of the available generator models.
Definition: __init__.py:50
Callable inject_simulation(Union[None, basf2.Module] module=None)
Helper function to incorporate a module instance into add_simulation after BGOverlayInput.
Definition: __init__.py:365
int _verify_ninter(int ninter)
Function to verify that ninter is of type int and greater than zero.
Definition: __init__.py:116
Union[None, int] _verify_seed(Union[None, int] seed)
Function to verify that seed is None or of type int in the interval [-2^63, 2^63 - 1].
Definition: __init__.py:84
str _verify_checkpoint(Union[None, str, pathlib.Path] checkpoint)
Function to verify that checkpoint is None, or a str or pathlib.Path that names an existing file after environment-variable expansion.
Definition: __init__.py:66
str _verify_globaltag(str globaltag)
Function to verify that globaltag is of type str.
Definition: __init__.py:130